Skip to content

fix: improve clarity of dealing with auth() during policy generation #293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/language/src/ast.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export type ResolvedShape = ExpressionType | AbstractDeclaration;
export type ResolvedType = {
decl?: ResolvedShape;
array?: boolean;
nullable?: boolean;
};

export const BinaryExprOperatorPriority: Record<BinaryExpr['operator'], number> = {
Expand Down
14 changes: 14 additions & 0 deletions packages/runtime/src/validation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,17 @@ export function validate(validator: z.ZodType, data: unknown) {
throw new ValidationError(fromZodError(err as z.ZodError).message);
}
}

/**
* Check if the given object has all the given fields, not null or undefined
* @param obj
* @param fields
* @returns
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function hasAllFields(obj: any, fields: string[]) {
if (typeof obj !== 'object' || !obj) {
return false;
}
return fields.every((f) => obj[f] !== undefined && obj[f] !== null);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { BinaryExpr, Expression, isArrayExpr, isBinaryExpr, isEnum, isLiteralExpr } from '@zenstackhq/language/ast';
import { ValidationAcceptor } from 'langium';
import { isAuthInvocation, isDataModelFieldReference, isEnumFieldReference } from '../../utils/ast-utils';
import { getDataModelFieldReference, isAuthInvocation, isEnumFieldReference } from '../../utils/ast-utils';
import { AstValidator } from '../types';

/**
Expand Down Expand Up @@ -33,7 +33,7 @@ export default class ExpressionValidator implements AstValidator<Expression> {
private validateBinaryExpr(expr: BinaryExpr, accept: ValidationAcceptor) {
switch (expr.operator) {
case 'in': {
if (!isDataModelFieldReference(expr.left)) {
if (!getDataModelFieldReference(expr.left)) {
accept('error', 'left operand of "in" must be a field reference', { node: expr.left });
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
isLiteralExpr,
} from '@zenstackhq/language/ast';
import { ValidationAcceptor } from 'langium';
import { isDataModelFieldReference, isEnumFieldReference } from '../../utils/ast-utils';
import { getDataModelFieldReference, isEnumFieldReference } from '../../utils/ast-utils';
import { FILTER_OPERATOR_FUNCTIONS } from '../constants';
import { AstValidator } from '../types';
import { isFromStdlib } from '../utils';
Expand Down Expand Up @@ -38,7 +38,7 @@ export default class FunctionInvocationValidator implements AstValidator<Express
// first argument must refer to a model field
const firstArg = expr.args?.[0]?.value;
if (firstArg) {
if (!isDataModelFieldReference(firstArg)) {
if (!getDataModelFieldReference(firstArg)) {
accept('error', 'first argument must be a field reference', { node: firstArg });
}
}
Expand Down
14 changes: 10 additions & 4 deletions packages/schema/src/language-server/zmodel-linker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import {
isArrayExpr,
isDataModel,
isDataModelField,
isDataModelFieldType,
isReferenceExpr,
LiteralExpr,
MemberAccessExpr,
Expand Down Expand Up @@ -249,7 +250,7 @@ export class ZModelLinker extends DefaultLinker {
const model = getContainingModel(node);
const userModel = model?.declarations.find((d) => isDataModel(d) && d.name === 'User');
if (userModel) {
node.$resolvedType = { decl: userModel };
node.$resolvedType = { decl: userModel, nullable: true };
}
} else if (funcDecl.name === 'future' && isFromStdlib(funcDecl)) {
// future() function is resolved to current model
Expand Down Expand Up @@ -447,19 +448,24 @@ export class ZModelLinker extends DefaultLinker {
//#region Utils

private resolveToDeclaredType(node: AstNode, type: FunctionParamType | DataModelFieldType) {
let nullable = false;
if (isDataModelFieldType(type)) {
nullable = type.optional;
}
if (type.type) {
const mappedType = mapBuiltinTypeToExpressionType(type.type);
node.$resolvedType = { decl: mappedType, array: type.array };
node.$resolvedType = { decl: mappedType, array: type.array, nullable: nullable };
} else if (type.reference) {
node.$resolvedType = {
decl: type.reference.ref,
array: type.array,
nullable: nullable,
};
}
}

private resolveToBuiltinTypeOrDecl(node: AstNode, type: ResolvedShape, array = false) {
node.$resolvedType = { decl: type, array };
private resolveToBuiltinTypeOrDecl(node: AstNode, type: ResolvedShape, array = false, nullable = false) {
node.$resolvedType = { decl: type, array, nullable };
}

//#endregion
Expand Down
186 changes: 130 additions & 56 deletions packages/schema/src/plugins/access-policy/expression-writer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import {
import { getLiteral, GUARD_FIELD_NAME, PluginError } from '@zenstackhq/sdk';
import { CodeBlockWriter } from 'ts-morph';
import { FILTER_OPERATOR_FUNCTIONS } from '../../language-server/constants';
import { getIdField, isAuthInvocation } from '../../utils/ast-utils';
import { getIdFields, isAuthInvocation } from '../../utils/ast-utils';
import TypeScriptExpressionTransformer from './typescript-expression-transformer';
import { isFutureExpr } from './utils';

Expand Down Expand Up @@ -99,12 +99,17 @@ export class ExpressionWriter {

private writeMemberAccess(expr: MemberAccessExpr) {
this.block(() => {
// must be a boolean member
this.writeFieldCondition(expr.operand, () => {
this.block(() => {
this.writer.write(`${expr.member.ref?.name}: true`);
if (this.isAuthOrAuthMemberAccess(expr)) {
// member access of `auth()`, generate plain expression
this.guard(() => this.plain(expr), true);
} else {
// must be a boolean member
this.writeFieldCondition(expr.operand, () => {
this.block(() => {
this.writer.write(`${expr.member.ref?.name}: true`);
});
});
});
}
});
}

Expand Down Expand Up @@ -190,9 +195,14 @@ export class ExpressionWriter {
return false;
}

private guard(write: () => void) {
private guard(write: () => void, cast = false) {
this.writer.write(`${GUARD_FIELD_NAME}: `);
write();
if (cast) {
this.writer.write('!!');
write();
} else {
write();
}
}

private plain(expr: Expression) {
Expand All @@ -211,12 +221,9 @@ export class ExpressionWriter {
// compile down to a plain expression
this.block(() => {
this.guard(() => {
this.plain(expr.left);
this.writer.write(' ' + operator + ' ');
this.plain(expr.right);
this.plain(expr);
});
});

return;
}

Expand All @@ -242,65 +249,105 @@ export class ExpressionWriter {
} as ReferenceExpr;
}

// if the operand refers to auth(), need to build a guard to avoid
// using undefined user as filter (which means no filter to Prisma)
// if auth() evaluates falsy, just treat the condition as false
if (this.isAuthOrAuthMemberAccess(operand)) {
this.writer.write(`!user ? { ${GUARD_FIELD_NAME}: false } : `);
// guard member access of `auth()` with null check
if (this.isAuthOrAuthMemberAccess(operand) && !fieldAccess.$resolvedType?.nullable) {
this.writer.write(
`(${this.plainExprBuilder.transform(operand)} == null) ? { ${GUARD_FIELD_NAME}: ${
// auth().x != user.x is true when auth().x is null and user is not nullable
// other expressions are evaluated to false when null is involved
operator === '!=' ? 'true' : 'false'
} } : `
);
}

this.block(() => {
this.writeFieldCondition(fieldAccess, () => {
this.block(
() => {
this.block(
() => {
this.writeFieldCondition(fieldAccess, () => {
this.block(() => {
const dataModel = this.isModelTyped(fieldAccess);
if (dataModel) {
const idField = getIdField(dataModel);
if (!idField) {
if (dataModel && isAuthInvocation(operand)) {
// right now this branch only serves comparison with `auth`, like
// @@allow('all', owner == auth())

const idFields = getIdFields(dataModel);
if (!idFields || idFields.length === 0) {
throw new PluginError(`Data model ${dataModel.name} does not have an id field`);
}
// comparing with an object, convert to "id" comparison instead
this.writer.write(`${idField.name}: `);

if (operator !== '==' && operator !== '!=') {
throw new PluginError('Only == and != operators are allowed');
}

if (!isThisExpr(fieldAccess)) {
this.writer.writeLine(operator === '==' ? 'is:' : 'isNot:');
const fieldIsNullable = !!fieldAccess.$resolvedType?.nullable;
if (fieldIsNullable) {
// if field is nullable, we can generate "null" check condition
this.writer.write(`(user == null) ? null : `);
}
}

this.block(() => {
this.writeOperator(operator, () => {
// operand ? operand.field : null
this.writer.write('(');
this.plain(operand);
this.writer.write(' ? ');
this.plain(operand);
this.writer.write(`.${idField.name}`);
this.writer.write(' : null');
this.writer.write(')');
idFields.forEach((idField, idx) => {
const writeIdsCheck = () => {
// id: user.id
this.writer.write(`${idField.name}:`);
this.plain(operand);
this.writer.write(`.${idField.name}`);
if (idx !== idFields.length - 1) {
this.writer.write(',');
}
};

if (isThisExpr(fieldAccess) && operator === '!=') {
// wrap a not
this.writer.writeLine('NOT:');
this.block(() => writeIdsCheck());
} else {
writeIdsCheck();
}
});
});
} else {
this.writeOperator(operator, () => {
this.writeOperator(operator, fieldAccess, () => {
this.plain(operand);
});
}
},
// "this" expression is compiled away (to .id access), so we should
// avoid generating a new layer
!isThisExpr(fieldAccess)
);
});
});
}, !isThisExpr(fieldAccess));
});
},
// "this" expression is compiled away (to .id access), so we should
// avoid generating a new layer
!isThisExpr(fieldAccess)
);
}

private isAuthOrAuthMemberAccess(expr: Expression) {
return isAuthInvocation(expr) || (isMemberAccessExpr(expr) && isAuthInvocation(expr.operand));
}

private writeOperator(operator: ComparisonOperator, writeOperand: () => void) {
if (operator === '!=') {
// wrap a 'not'
this.writer.write('not: ');
this.block(() => {
this.writeOperator('==', writeOperand);
});
} else {
this.writer.write(`${this.mapOperator(operator)}: `);
private writeOperator(operator: ComparisonOperator, fieldAccess: Expression, writeOperand: () => void) {
if (isDataModel(fieldAccess.$resolvedType?.decl)) {
if (operator === '==') {
this.writer.write('is: ');
} else if (operator === '!=') {
this.writer.write('isNot: ');
} else {
throw new PluginError('Only == and != operators are allowed for data model comparison');
}
writeOperand();
} else {
if (operator === '!=') {
// wrap a 'not'
this.writer.write('not: ');
this.block(() => {
this.writer.write(`${this.mapOperator('==')}: `);
writeOperand();
});
} else {
this.writer.write(`${this.mapOperator(operator)}: `);
writeOperand();
}
}
}

Expand Down Expand Up @@ -414,10 +461,37 @@ export class ExpressionWriter {
}

private writeLogical(expr: BinaryExpr, operator: '&&' | '||') {
this.block(() => {
this.writer.write(`${operator === '&&' ? 'AND' : 'OR'}: `);
this.writeExprList([expr.left, expr.right]);
});
// TODO: do we need short-circuit for logical operators?

if (operator === '&&') {
// // && short-circuit: left && right -> left ? right : { zenstack_guard: false }
// if (!this.hasFieldAccess(expr.left)) {
// this.plain(expr.left);
// this.writer.write(' ? ');
// this.write(expr.right);
// this.writer.write(' : ');
// this.block(() => this.guard(() => this.writer.write('false')));
// } else {
this.block(() => {
this.writer.write('AND:');
this.writeExprList([expr.left, expr.right]);
});
// }
} else {
// // || short-circuit: left || right -> left ? { zenstack_guard: true } : right
// if (!this.hasFieldAccess(expr.left)) {
// this.plain(expr.left);
// this.writer.write(' ? ');
// this.block(() => this.guard(() => this.writer.write('true')));
// this.writer.write(' : ');
// this.write(expr.right);
// } else {
this.block(() => {
this.writer.write('OR:');
this.writeExprList([expr.left, expr.right]);
});
// }
}
}

private writeUnary(expr: UnaryExpr) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import path from 'path';
import { FunctionDeclaration, Project, SourceFile, VariableDeclarationKind } from 'ts-morph';
import { name } from '.';
import { isFromStdlib } from '../../language-server/utils';
import { analyzePolicies, getIdField } from '../../utils/ast-utils';
import { analyzePolicies, getIdFields } from '../../utils/ast-utils';
import { ALL_OPERATION_KINDS, getDefaultOutputFolder, RUNTIME_PACKAGE } from '../plugin-utils';
import { ExpressionWriter } from './expression-writer';
import { isFutureExpr } from './utils';
Expand All @@ -42,9 +42,8 @@ export default class PolicyGenerator {
const sf = project.createSourceFile(path.join(output, 'policy.ts'), undefined, { overwrite: true });

sf.addImportDeclaration({
namedImports: [{ name: 'QueryContext' }],
namedImports: [{ name: 'type QueryContext' }, { name: 'hasAllFields' }],
moduleSpecifier: `${RUNTIME_PACKAGE}`,
isTypeOnly: true,
});

sf.addImportDeclaration({
Expand Down Expand Up @@ -329,13 +328,17 @@ export default class PolicyGenerator {
if (!userModel) {
throw new PluginError('User model not found');
}
const userIdField = getIdField(userModel);
if (!userIdField) {
const userIdFields = getIdFields(userModel);
if (!userIdFields || userIdFields.length === 0) {
throw new PluginError('User model does not have an id field');
}

// normalize user to null to avoid accidentally use undefined in filter
func.addStatements(`const user = context.user ?? null;`);
func.addStatements(
`const user = hasAllFields(context.user, [${userIdFields
.map((f) => "'" + f.name + "'")
.join(', ')}]) ? context.user : null;`
);
}

// r = <guard object>;
Expand Down
Loading