From ce85146f0b2c01a0b3f610f457c446b72b6cbee9 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 27 Mar 2023 23:40:01 +0800 Subject: [PATCH 1/2] fix: improve clarity of dealing with `auth()` during policy generation --- packages/language/src/ast.ts | 1 + packages/runtime/src/validation.ts | 14 + .../validator/expression-validator.ts | 4 +- .../function-invocation-validator.ts | 4 +- .../src/language-server/zmodel-linker.ts | 14 +- .../access-policy/expression-writer.ts | 186 +++++--- .../access-policy/policy-guard-generator.ts | 15 +- .../typescript-expression-transformer.ts | 18 +- packages/schema/src/utils/ast-utils.ts | 36 +- .../tests/generator/expression-writer.test.ts | 407 +++++++++++++++--- tests/integration/test-run/package-lock.json | 6 +- .../tests/with-policy/auth.test.ts | 9 +- .../tests/with-policy/multi-id-fields.test.ts | 85 ++++ 13 files changed, 659 insertions(+), 140 deletions(-) diff --git a/packages/language/src/ast.ts b/packages/language/src/ast.ts index e5abfb58c..b9888eb9d 100644 --- a/packages/language/src/ast.ts +++ b/packages/language/src/ast.ts @@ -14,6 +14,7 @@ export type ResolvedShape = ExpressionType | AbstractDeclaration; export type ResolvedType = { decl?: ResolvedShape; array?: boolean; + nullable?: boolean; }; export const BinaryExprOperatorPriority: Record = { diff --git a/packages/runtime/src/validation.ts b/packages/runtime/src/validation.ts index ed0ddbfb7..33115f8e9 100644 --- a/packages/runtime/src/validation.ts +++ b/packages/runtime/src/validation.ts @@ -18,3 +18,17 @@ export function validate(validator: z.ZodType, data: unknown) { throw new ValidationError(fromZodError(err as z.ZodError).message); } } + +/** + * Check if the given object has all the given fields, not null or undefined + * @param obj + * @param fields + * @returns + */ +// eslint-disable-next-line @typescript-eslint/no-explicit-any +export function hasAllFields(obj: any, fields: string[]) { + if (typeof obj !== 'object' || !obj) { + return false; + } + return fields.every((f) => obj[f] !== undefined && obj[f] !== null); +} diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts index ea15766db..64cdec539 100644 --- a/packages/schema/src/language-server/validator/expression-validator.ts +++ b/packages/schema/src/language-server/validator/expression-validator.ts @@ -1,6 +1,6 @@ import { BinaryExpr, Expression, isArrayExpr, isBinaryExpr, isEnum, isLiteralExpr } from '@zenstackhq/language/ast'; import { ValidationAcceptor } from 'langium'; -import { isAuthInvocation, isDataModelFieldReference, isEnumFieldReference } from '../../utils/ast-utils'; +import { getDataModelFieldReference, isAuthInvocation, isEnumFieldReference } from '../../utils/ast-utils'; import { AstValidator } from '../types'; /** @@ -33,7 +33,7 @@ export default class ExpressionValidator implements AstValidator { private validateBinaryExpr(expr: BinaryExpr, accept: ValidationAcceptor) { switch (expr.operator) { case 'in': { - if (!isDataModelFieldReference(expr.left)) { + if (!getDataModelFieldReference(expr.left)) { accept('error', 'left operand of "in" must be a field reference', { node: expr.left }); } diff --git a/packages/schema/src/language-server/validator/function-invocation-validator.ts b/packages/schema/src/language-server/validator/function-invocation-validator.ts index e5a5c76f0..d02be18f4 100644 --- a/packages/schema/src/language-server/validator/function-invocation-validator.ts +++ b/packages/schema/src/language-server/validator/function-invocation-validator.ts @@ -8,7 +8,7 @@ import { isLiteralExpr, } from '@zenstackhq/language/ast'; import { ValidationAcceptor } from 'langium'; -import { isDataModelFieldReference, isEnumFieldReference } from '../../utils/ast-utils'; +import { getDataModelFieldReference, isEnumFieldReference } from '../../utils/ast-utils'; import { FILTER_OPERATOR_FUNCTIONS } from '../constants'; import { AstValidator } from '../types'; import { isFromStdlib } from '../utils'; @@ -38,7 +38,7 @@ export default class FunctionInvocationValidator implements AstValidator isDataModel(d) && d.name === 'User'); if (userModel) { - node.$resolvedType = { decl: userModel }; + node.$resolvedType = { decl: userModel, nullable: true }; } } else if (funcDecl.name === 'future' && isFromStdlib(funcDecl)) { // future() function is resolved to current model @@ -447,19 +448,24 @@ export class ZModelLinker extends DefaultLinker { //#region Utils private resolveToDeclaredType(node: AstNode, type: FunctionParamType | DataModelFieldType) { + let nullable = false; + if (isDataModelFieldType(type)) { + nullable = type.optional; + } if (type.type) { const mappedType = mapBuiltinTypeToExpressionType(type.type); - node.$resolvedType = { decl: mappedType, array: type.array }; + node.$resolvedType = { decl: mappedType, array: type.array, nullable: nullable }; } else if (type.reference) { node.$resolvedType = { decl: type.reference.ref, array: type.array, + nullable: nullable, }; } } - private resolveToBuiltinTypeOrDecl(node: AstNode, type: ResolvedShape, array = false) { - node.$resolvedType = { decl: type, array }; + private resolveToBuiltinTypeOrDecl(node: AstNode, type: ResolvedShape, array = false, nullable = false) { + node.$resolvedType = { decl: type, array, nullable }; } //#endregion diff --git a/packages/schema/src/plugins/access-policy/expression-writer.ts b/packages/schema/src/plugins/access-policy/expression-writer.ts index e80fcf43c..5b3e85850 100644 --- a/packages/schema/src/plugins/access-policy/expression-writer.ts +++ b/packages/schema/src/plugins/access-policy/expression-writer.ts @@ -17,7 +17,7 @@ import { import { getLiteral, GUARD_FIELD_NAME, PluginError } from '@zenstackhq/sdk'; import { CodeBlockWriter } from 'ts-morph'; import { FILTER_OPERATOR_FUNCTIONS } from '../../language-server/constants'; -import { getIdField, isAuthInvocation } from '../../utils/ast-utils'; +import { getIdFields, isAuthInvocation } from '../../utils/ast-utils'; import TypeScriptExpressionTransformer from './typescript-expression-transformer'; import { isFutureExpr } from './utils'; @@ -99,12 +99,17 @@ export class ExpressionWriter { private writeMemberAccess(expr: MemberAccessExpr) { this.block(() => { - // must be a boolean member - this.writeFieldCondition(expr.operand, () => { - this.block(() => { - this.writer.write(`${expr.member.ref?.name}: true`); + if (this.isAuthOrAuthMemberAccess(expr)) { + // member access of `auth()`, generate plain expression + this.guard(() => this.plain(expr), true); + } else { + // must be a boolean member + this.writeFieldCondition(expr.operand, () => { + this.block(() => { + this.writer.write(`${expr.member.ref?.name}: true`); + }); }); - }); + } }); } @@ -190,9 +195,14 @@ export class ExpressionWriter { return false; } - private guard(write: () => void) { + private guard(write: () => void, cast = false) { this.writer.write(`${GUARD_FIELD_NAME}: `); - write(); + if (cast) { + this.writer.write('!!'); + write(); + } else { + write(); + } } private plain(expr: Expression) { @@ -211,12 +221,9 @@ export class ExpressionWriter { // compile down to a plain expression this.block(() => { this.guard(() => { - this.plain(expr.left); - this.writer.write(' ' + operator + ' '); - this.plain(expr.right); + this.plain(expr); }); }); - return; } @@ -242,65 +249,105 @@ export class ExpressionWriter { } as ReferenceExpr; } - // if the operand refers to auth(), need to build a guard to avoid - // using undefined user as filter (which means no filter to Prisma) - // if auth() evaluates falsy, just treat the condition as false - if (this.isAuthOrAuthMemberAccess(operand)) { - this.writer.write(`!user ? { ${GUARD_FIELD_NAME}: false } : `); + // guard member access of `auth()` with null check + if (this.isAuthOrAuthMemberAccess(operand) && !fieldAccess.$resolvedType?.nullable) { + this.writer.write( + `(${this.plainExprBuilder.transform(operand)} == null) ? { ${GUARD_FIELD_NAME}: ${ + // auth().x != user.x is true when auth().x is null and user is not nullable + // other expressions are evaluated to false when null is involved + operator === '!=' ? 'true' : 'false' + } } : ` + ); } - this.block(() => { - this.writeFieldCondition(fieldAccess, () => { - this.block( - () => { + this.block( + () => { + this.writeFieldCondition(fieldAccess, () => { + this.block(() => { const dataModel = this.isModelTyped(fieldAccess); - if (dataModel) { - const idField = getIdField(dataModel); - if (!idField) { + if (dataModel && isAuthInvocation(operand)) { + // right now this branch only serves comparison with `auth`, like + // @@allow('all', owner == auth()) + + const idFields = getIdFields(dataModel); + if (!idFields || idFields.length === 0) { throw new PluginError(`Data model ${dataModel.name} does not have an id field`); } - // comparing with an object, convert to "id" comparison instead - this.writer.write(`${idField.name}: `); + + if (operator !== '==' && operator !== '!=') { + throw new PluginError('Only == and != operators are allowed'); + } + + if (!isThisExpr(fieldAccess)) { + this.writer.writeLine(operator === '==' ? 'is:' : 'isNot:'); + const fieldIsNullable = !!fieldAccess.$resolvedType?.nullable; + if (fieldIsNullable) { + // if field is nullable, we can generate "null" check condition + this.writer.write(`(user == null) ? null : `); + } + } + this.block(() => { - this.writeOperator(operator, () => { - // operand ? operand.field : null - this.writer.write('('); - this.plain(operand); - this.writer.write(' ? '); - this.plain(operand); - this.writer.write(`.${idField.name}`); - this.writer.write(' : null'); - this.writer.write(')'); + idFields.forEach((idField, idx) => { + const writeIdsCheck = () => { + // id: user.id + this.writer.write(`${idField.name}:`); + this.plain(operand); + this.writer.write(`.${idField.name}`); + if (idx !== idFields.length - 1) { + this.writer.write(','); + } + }; + + if (isThisExpr(fieldAccess) && operator === '!=') { + // wrap a not + this.writer.writeLine('NOT:'); + this.block(() => writeIdsCheck()); + } else { + writeIdsCheck(); + } }); }); } else { - this.writeOperator(operator, () => { + this.writeOperator(operator, fieldAccess, () => { this.plain(operand); }); } - }, - // "this" expression is compiled away (to .id access), so we should - // avoid generating a new layer - !isThisExpr(fieldAccess) - ); - }); - }); + }, !isThisExpr(fieldAccess)); + }); + }, + // "this" expression is compiled away (to .id access), so we should + // avoid generating a new layer + !isThisExpr(fieldAccess) + ); } private isAuthOrAuthMemberAccess(expr: Expression) { return isAuthInvocation(expr) || (isMemberAccessExpr(expr) && isAuthInvocation(expr.operand)); } - private writeOperator(operator: ComparisonOperator, writeOperand: () => void) { - if (operator === '!=') { - // wrap a 'not' - this.writer.write('not: '); - this.block(() => { - this.writeOperator('==', writeOperand); - }); - } else { - this.writer.write(`${this.mapOperator(operator)}: `); + private writeOperator(operator: ComparisonOperator, fieldAccess: Expression, writeOperand: () => void) { + if (isDataModel(fieldAccess.$resolvedType?.decl)) { + if (operator === '==') { + this.writer.write('is: '); + } else if (operator === '!=') { + this.writer.write('isNot: '); + } else { + throw new PluginError('Only == and != operators are allowed for data model comparison'); + } writeOperand(); + } else { + if (operator === '!=') { + // wrap a 'not' + this.writer.write('not: '); + this.block(() => { + this.writer.write(`${this.mapOperator(operator)}: `); + writeOperand(); + }); + } else { + this.writer.write(`${this.mapOperator(operator)}: `); + writeOperand(); + } } } @@ -414,10 +461,37 @@ export class ExpressionWriter { } private writeLogical(expr: BinaryExpr, operator: '&&' | '||') { - this.block(() => { - this.writer.write(`${operator === '&&' ? 'AND' : 'OR'}: `); - this.writeExprList([expr.left, expr.right]); - }); + // TODO: do we need short-circuit for logical operators? + + if (operator === '&&') { + // // && short-circuit: left && right -> left ? right : { zenstack_guard: false } + // if (!this.hasFieldAccess(expr.left)) { + // this.plain(expr.left); + // this.writer.write(' ? '); + // this.write(expr.right); + // this.writer.write(' : '); + // this.block(() => this.guard(() => this.writer.write('false'))); + // } else { + this.block(() => { + this.writer.write('AND:'); + this.writeExprList([expr.left, expr.right]); + }); + // } + } else { + // // || short-circuit: left || right -> left ? { zenstack_guard: true } : right + // if (!this.hasFieldAccess(expr.left)) { + // this.plain(expr.left); + // this.writer.write(' ? '); + // this.block(() => this.guard(() => this.writer.write('true'))); + // this.writer.write(' : '); + // this.write(expr.right); + // } else { + this.block(() => { + this.writer.write('OR:'); + this.writeExprList([expr.left, expr.right]); + }); + // } + } } private writeUnary(expr: UnaryExpr) { diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index 9f0fcd5af..06a70336b 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -21,7 +21,7 @@ import path from 'path'; import { FunctionDeclaration, Project, SourceFile, VariableDeclarationKind } from 'ts-morph'; import { name } from '.'; import { isFromStdlib } from '../../language-server/utils'; -import { analyzePolicies, getIdField } from '../../utils/ast-utils'; +import { analyzePolicies, getIdFields } from '../../utils/ast-utils'; import { ALL_OPERATION_KINDS, getDefaultOutputFolder, RUNTIME_PACKAGE } from '../plugin-utils'; import { ExpressionWriter } from './expression-writer'; import { isFutureExpr } from './utils'; @@ -42,9 +42,8 @@ export default class PolicyGenerator { const sf = project.createSourceFile(path.join(output, 'policy.ts'), undefined, { overwrite: true }); sf.addImportDeclaration({ - namedImports: [{ name: 'QueryContext' }], + namedImports: [{ name: 'type QueryContext' }, { name: 'hasAllFields' }], moduleSpecifier: `${RUNTIME_PACKAGE}`, - isTypeOnly: true, }); sf.addImportDeclaration({ @@ -329,13 +328,17 @@ export default class PolicyGenerator { if (!userModel) { throw new PluginError('User model not found'); } - const userIdField = getIdField(userModel); - if (!userIdField) { + const userIdFields = getIdFields(userModel); + if (!userIdFields || userIdFields.length === 0) { throw new PluginError('User model does not have an id field'); } // normalize user to null to avoid accidentally use undefined in filter - func.addStatements(`const user = context.user ?? null;`); + func.addStatements( + `const user = hasAllFields(context.user, [${userIdFields + .map((f) => "'" + f.name + "'") + .join(', ')}]) ? context.user : null;` + ); } // r = ; diff --git a/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts b/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts index 961b3028f..cb6dfba5e 100644 --- a/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts +++ b/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts @@ -1,5 +1,6 @@ import { ArrayExpr, + BinaryExpr, Expression, InvocationExpr, isEnumField, @@ -9,6 +10,7 @@ import { NullExpr, ReferenceExpr, ThisExpr, + UnaryExpr, } from '@zenstackhq/language/ast'; import { PluginError } from '@zenstackhq/sdk'; import { isAuthInvocation } from '../../utils/ast-utils'; @@ -53,6 +55,12 @@ export default class TypeScriptExpressionTransformer { case MemberAccessExpr: return this.memberAccess(expr as MemberAccessExpr); + case UnaryExpr: + return this.unary(expr as UnaryExpr); + + case BinaryExpr: + return this.binary(expr as BinaryExpr); + default: throw new PluginError(`Unsupported expression type: ${expr.$type}`); } @@ -78,7 +86,7 @@ export default class TypeScriptExpressionTransformer { return expr.member.ref.name; } else { // normalize field access to null instead of undefined to avoid accidentally use undefined in filter - return `(${this.transform(expr.operand)} ? ${this.transform(expr.operand)}.${expr.member.ref.name} : null)`; + return `(${this.transform(expr.operand)}?.${expr.member.ref.name} ?? null)`; } } @@ -124,4 +132,12 @@ export default class TypeScriptExpressionTransformer { return expr.value.toString(); } } + + private unary(expr: UnaryExpr): string { + return `(${expr.operator} ${this.transform(expr.operand)})`; + } + + private binary(expr: BinaryExpr): string { + return `(${this.transform(expr.left)} ${expr.operator} ${this.transform(expr.right)})`; + } } diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index 7452d50fa..5456a2670 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -1,7 +1,9 @@ import { DataModel, DataModelAttribute, + DataModelField, Expression, + isArrayExpr, isDataModel, isDataModelField, isEnumField, @@ -9,6 +11,7 @@ import { isMemberAccessExpr, isReferenceExpr, Model, + ReferenceExpr, } from '@zenstackhq/language/ast'; import { PolicyOperationKind } from '@zenstackhq/runtime'; import { getLiteral } from '@zenstackhq/sdk'; @@ -100,8 +103,25 @@ export const VALIDATION_ATTRIBUTES = [ '@lte', ]; -export function getIdField(dataModel: DataModel) { - return dataModel.fields.find((f) => f.attributes.some((attr) => attr.decl.$refText === '@id')); +export function getIdFields(dataModel: DataModel) { + const fieldLevelId = dataModel.fields.find((f) => f.attributes.some((attr) => attr.decl.$refText === '@id')); + if (fieldLevelId) { + return [fieldLevelId]; + } else { + // get model level @@id attribute + const modelIdAttr = dataModel.attributes.find((attr) => attr.decl?.ref?.name === '@@id'); + if (modelIdAttr) { + // get fields referenced in the attribute: @@id([field1, field2]]) + if (!isArrayExpr(modelIdAttr.args[0].value)) { + return []; + } + const argValue = modelIdAttr.args[0].value; + return argValue.items + .filter((expr): expr is ReferenceExpr => isReferenceExpr(expr) && !!getDataModelFieldReference(expr)) + .map((expr) => expr.target.ref as DataModelField); + } + } + return []; } export function isAuthInvocation(expr: Expression) { @@ -112,12 +132,12 @@ export function isEnumFieldReference(expr: Expression) { return isReferenceExpr(expr) && isEnumField(expr.target.ref); } -export function isDataModelFieldReference(expr: Expression): boolean { - if (isReferenceExpr(expr)) { - return isDataModelField(expr.target.ref); - } else if (isMemberAccessExpr(expr)) { - return true; +export function getDataModelFieldReference(expr: Expression): DataModelField | undefined { + if (isReferenceExpr(expr) && isDataModelField(expr.target.ref)) { + return expr.target.ref; + } else if (isMemberAccessExpr(expr) && isDataModelField(expr.member.ref)) { + return expr.member.ref; } else { - return false; + return undefined; } } diff --git a/packages/schema/tests/generator/expression-writer.test.ts b/packages/schema/tests/generator/expression-writer.test.ts index 8476418b0..6f50d50d3 100644 --- a/packages/schema/tests/generator/expression-writer.test.ts +++ b/packages/schema/tests/generator/expression-writer.test.ts @@ -119,13 +119,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `!user ? - { zenstack_guard: false } : - { - id: { - equals: (user ? user.id: null) - } - }` + `(user == null) ? { zenstack_guard: false } : { id: user.id }` ); await check( @@ -137,15 +131,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `!user ? - { zenstack_guard: false } : - { - id: { - not: { - equals: (user ? user.id: null) - } - } - }` + `(user == null) ? { zenstack_guard: true } : { NOT: { id: user.id } }` ); await check( @@ -536,33 +522,113 @@ describe('Expression Writer Tests', () => { ); }); - it('auth check', async () => { + it('auth null check', async () => { await check( ` - model User { id String @id } + model User { + id String @id + } + model Test { id String @id - @@deny('all', auth() == null) + @@allow('all', auth() == null) } `, (model) => model.attributes[0].args[1].value, - `{ ${GUARD_FIELD_NAME}: user == null }` + `{ zenstack_guard: (user == null) }`, + '{ id: "1" }' ); await check( ` - model User { id String @id } + model User { + x String + y String + @@id([x, y]) + } + + model Test { + id String @id + @@allow('all', auth() == null) + } + `, + (model) => model.attributes[0].args[1].value, + `{ zenstack_guard: (user == null) }`, + '{ x: "1", y: "2" }' + ); + + await check( + ` + model User { + id String @id + } + model Test { id String @id @@allow('all', auth() != null) } `, (model) => model.attributes[0].args[1].value, - `{ ${GUARD_FIELD_NAME}: user != null }` + `{ zenstack_guard: (user != null) }`, + '{ id: "1" }' + ); + + await check( + ` + model User { + x String + y String + @@id([x, y]) + } + + model Test { + id String @id + @@allow('all', auth() != null) + } + `, + (model) => model.attributes[0].args[1].value, + `{ zenstack_guard: (user != null) }`, + '{ x: "1", y: "2" }' + ); + }); + + it('auth boolean field check', async () => { + await check( + ` + model User { + id String @id + admin Boolean + } + + model Test { + id String @id + @@allow('all', auth().admin) + } + `, + (model) => model.attributes[0].args[1].value, + `{ zenstack_guard: !!(user?.admin ?? null) }`, + '{ id: "1", admin: true }' + ); + + await check( + ` + model User { + id String @id + admin Boolean + } + + model Test { + id String @id + @@deny('all', !auth().admin) + } + `, + (model) => model.attributes[0].args[1].value, + `{ NOT: { zenstack_guard: !!(user?.admin ?? null) } }`, + '{ id: "1", admin: true }' ); }); - it('auth check against field', async () => { + it('auth check against field single id', async () => { await check( ` model User { @@ -578,16 +644,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `!user ? - { zenstack_guard : false } : - { - owner: { - id: { - equals: (user ? user.id : null) - } - } - } - ` + `(user==null) ? { zenstack_guard: false } : { owner: { is: { id : user.id } } }` ); await check( @@ -605,17 +662,12 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `!user ? - { zenstack_guard : false } : - { - owner: { - id: { - not: { - equals: (user ? user.id : null) - } - } - } - }` + `(user==null) ? { zenstack_guard: true } : + { + owner: { + isNot: { id: user.id } + } + }` ); await check( @@ -633,15 +685,261 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `!user ? + `((user?.id??null)==null) ? { zenstack_guard : false } : - { - owner: { - id: { - equals: (user ? user.id : null) - } - } - }` + { owner: { id: { equals: (user?.id ?? null) } } }` + ); + }); + + it('auth check against field multi-id', async () => { + await check( + ` + model User { + x String + y String + t Test? + @@id([x, y]) + } + + model Test { + id String @id + owner User @relation(fields: [ownerX, ownerY], references: [x, y]) + ownerX String + ownerY String + @@unique([ownerX, ownerY]) + @@allow('all', auth() == owner) + } + `, + (model) => model.attributes[1].args[1].value, + `(user==null) ? + { zenstack_guard: false } : + { owner: { is: { x: user.x, y: user.y } } }`, + '{ x: "1", y: "2" }' + ); + + await check( + ` + model User { + x String + y String + t Test? + @@id([x, y]) + } + + model Test { + id String @id + owner User @relation(fields: [ownerX, ownerY], references: [x, y]) + ownerX String + ownerY String + @@unique([ownerX, ownerY]) + @@allow('all', auth() != owner) + } + `, + (model) => model.attributes[1].args[1].value, + `(user==null) ? + { zenstack_guard: true } : + { owner: { isNot: { x: user.x, y: user.y } } }`, + '{ x: "1", y: "2" }' + ); + + await check( + ` + model User { + x String + y String + t Test? + @@id([x, y]) + } + + model Test { + id String @id + owner User @relation(fields: [ownerX, ownerY], references: [x, y]) + ownerX String + ownerY String + @@unique([ownerX, ownerY]) + @@allow('all', auth().x == owner.x && auth().y == owner.y) + } + `, + (model) => model.attributes[1].args[1].value, + `{ + AND: [ + ((user?.x??null)==null) ? { zenstack_guard: false } : { owner: { x: { equals: (user?.x ?? null) } } }, + ((user?.y??null)==null) ? { zenstack_guard: false } : { owner: { y: { equals: (user?.y ?? null) } } } + ] + }`, + '{ x: "1", y: "2" }' + ); + }); + + it('auth check against nullable field', async () => { + await check( + ` + model User { + id String @id + t Test? + } + + model Test { + id String @id + owner User? @relation(fields: [ownerId], references: [id]) + ownerId String? @unique + @@allow('all', auth() == owner) + } + `, + (model) => model.attributes[0].args[1].value, + `{ + owner: { + is: (user == null) ? null : { id: user.id } + } + }` + ); + + await check( + ` + model User { + id String @id + t Test? + } + + model Test { + id String @id + owner User? @relation(fields: [ownerId], references: [id]) + ownerId String? @unique + @@deny('all', auth() != owner) + } + `, + (model) => model.attributes[0].args[1].value, + `{ + owner: { + isNot: (user == null) ? null : { id: user.id } + } + }` + ); + + await check( + ` + model User { + id String @id + t Test? + } + + model Test { + id String @id + owner User? @relation(fields: [ownerId], references: [id]) + ownerId String? @unique + @@allow('all', auth().id == owner.id) + } + `, + (model) => model.attributes[0].args[1].value, + `((user?.id??null)==null) ? { zenstack_guard: false } : { owner: { id: { equals: (user?.id ?? null) } } }` + ); + }); + + it('auth check short-circuit [TBD]', async () => { + await check( + ` + model User { + id String @id + t Test? + } + + model Test { + id String @id + owner User @relation(fields: [ownerId], references: [id]) + ownerId String @unique + value Int + @@allow('all', auth() != null && auth().id == owner.id && value > 0) + } + `, + (model) => model.attributes[0].args[1].value, + `{ + AND: [ + { + AND: [ + { zenstack_guard: (user!=null) }, + ((user?.id??null)==null) ? {zenstack_guard:false} : { owner: { id: { equals: (user?.id??null) } } } + ] + }, + { value: { gt: 0 } } + ] + }` + ); + + await check( + ` + model User { + id String @id + t Test? + } + + model Test { + id String @id + owner User @relation(fields: [ownerId], references: [id]) + ownerId String @unique + value Int + @@deny('all', auth() == null || auth().id != owner.id || value <= 0) + } + `, + (model) => model.attributes[0].args[1].value, + `{ + OR: [ + { + OR: [ + { zenstack_guard:(user==null) }, + ((user?.id??null)==null) ? {zenstack_guard:true} : { owner : { id: { not: { equals: (user?.id??null) } } } } + ] + }, + { value: { lte: 0 } } + ] + }` + ); + }); + + it('relation field null check', async () => { + await check( + ` + model M { + id String @id + s String? + t Test @relation(fields: [tId], references: [id]) + tId String @unique + } + + model Test { + id String @id + m M? + @@allow('all', m == null || m.s == null) + } + `, + (model) => model.attributes[0].args[1].value, + ` + { + OR: [{ m: { equals: null } }, { m: { s: { equals: null } } }] + } + ` + ); + + await check( + ` + model M { + id String @id + s String? + t Test @relation(fields: [tId], references: [id]) + tId String @unique + } + + model Test { + id String @id + m M? + @@deny('all', m != null || m.s != null) + } + `, + (model) => model.attributes[0].args[1].value, + ` + { + OR: [{ m: { not: { equals: null } } }, { m: { s: { not: { equals: null } } } }] + } + ` ); }); @@ -836,7 +1134,7 @@ describe('Expression Writer Tests', () => { }); }); -async function check(schema: string, getExpr: (model: DataModel) => Expression, expected: string) { +async function check(schema: string, getExpr: (model: DataModel) => Expression, expected: string, userInit?: string) { if (!schema.includes('datasource ')) { schema = ` @@ -860,7 +1158,7 @@ async function check(schema: string, getExpr: (model: DataModel) => Expression, // inject user variable sf.addVariableStatement({ declarationKind: VariableDeclarationKind.Const, - declarations: [{ name: 'user', initializer: '{ id: "user1" }' }], + declarations: [{ name: 'user', initializer: userInit ?? '{ id: "user1" }' }], }); // inject enums @@ -894,6 +1192,7 @@ async function check(schema: string, getExpr: (model: DataModel) => Expression, sf.formatText(); await project.save(); + console.log('Source saved:', sourcePath); if (project.getPreEmitDiagnostics().length > 0) { for (const d of project.getPreEmitDiagnostics()) { diff --git a/tests/integration/test-run/package-lock.json b/tests/integration/test-run/package-lock.json index 14f918130..2cbb0df60 100644 --- a/tests/integration/test-run/package-lock.json +++ b/tests/integration/test-run/package-lock.json @@ -173,7 +173,7 @@ "colors": "1.4.0", "commander": "^8.3.0", "cuid": "^2.1.8", - "langium": "1.0.1", + "langium": "1.1.0", "mixpanel": "^0.17.0", "node-machine-id": "^1.1.12", "ora": "^5.4.1", @@ -221,6 +221,7 @@ "ts-node": "^10.9.1", "tsc-alias": "^1.7.0", "typescript": "^4.8.4", + "vitest": "^0.29.7", "vsce": "^2.13.0" }, "engines": { @@ -425,7 +426,7 @@ "eslint": "^8.27.0", "eslint-plugin-jest": "^27.1.7", "jest": "^29.2.1", - "langium": "1.0.1", + "langium": "1.1.0", "langium-cli": "^1.0.0", "mixpanel": "^0.17.0", "node-machine-id": "^1.1.12", @@ -444,6 +445,7 @@ "tsc-alias": "^1.7.0", "typescript": "^4.8.4", "uuid": "^9.0.0", + "vitest": "^0.29.7", "vsce": "^2.13.0", "vscode-jsonrpc": "^8.0.2", "vscode-languageclient": "^8.0.2", diff --git a/tests/integration/tests/with-policy/auth.test.ts b/tests/integration/tests/with-policy/auth.test.ts index ac74f451b..57c9fff08 100644 --- a/tests/integration/tests/with-policy/auth.test.ts +++ b/tests/integration/tests/with-policy/auth.test.ts @@ -1,7 +1,7 @@ import { loadSchema } from '@zenstackhq/testtools'; import path from 'path'; -describe('With Policy:undefined user', () => { +describe('With Policy: auth() test', () => { let origDir: string; const suite = 'undefined-user'; @@ -182,13 +182,12 @@ describe('With Policy:undefined user', () => { const db = withPolicy(); await expect(db.user.create({ data: { id: 'user1', role: 'USER' } })).toResolveTruthy(); await expect(db.post.create({ data: { id: '1', title: 'abc', authorId: 'user1' } })).toResolveTruthy(); - await expect(db.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); - const authDb = withPolicy({ role: 'USER' }); - await expect(db.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); + const authDb = withPolicy({ id: 'user1', role: 'USER' }); + await expect(authDb.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toBeRejectedByPolicy(); - const authDb1 = withPolicy({ role: 'ADMIN' }); + const authDb1 = withPolicy({ id: 'user2', role: 'ADMIN' }); await expect(authDb1.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy(); }); }); diff --git a/tests/integration/tests/with-policy/multi-id-fields.test.ts b/tests/integration/tests/with-policy/multi-id-fields.test.ts index f9984f98f..156b9e2ac 100644 --- a/tests/integration/tests/with-policy/multi-id-fields.test.ts +++ b/tests/integration/tests/with-policy/multi-id-fields.test.ts @@ -68,4 +68,89 @@ describe('With Policy: multiple id fields', () => { }) ).toResolveTruthy(); }); + + it('multi-id auth', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model User { + x String + y String + m M? + n N? + p P? + q Q? + @@id([x, y]) + @@allow('all', true) + } + + model M { + id String @id @default(cuid()) + owner User @relation(fields: [ownerX, ownerY], references: [x, y]) + ownerX String + ownerY String + @@unique([ownerX, ownerY]) + @@allow('all', auth() == owner) + } + + model N { + id String @id @default(cuid()) + owner User @relation(fields: [ownerX, ownerY], references: [x, y]) + ownerX String + ownerY String + @@unique([ownerX, ownerY]) + @@allow('all', auth().x == owner.x && auth().y == owner.y) + } + + model P { + id String @id @default(cuid()) + owner User @relation(fields: [ownerX, ownerY], references: [x, y]) + ownerX String + ownerY String + @@unique([ownerX, ownerY]) + @@allow('all', auth() != owner) + } + + model Q { + id String @id @default(cuid()) + owner User @relation(fields: [ownerX, ownerY], references: [x, y]) + ownerX String + ownerY String + @@unique([ownerX, ownerY]) + @@allow('all', auth() != null) + } + ` + ); + + await prisma.user.create({ data: { x: '1', y: '1' } }); + await prisma.user.create({ data: { x: '1', y: '2' } }); + + const anonDb = withPolicy({}); + + await expect( + anonDb.m.create({ data: { owner: { connect: { x_y: { x: '1', y: '2' } } } } }) + ).toBeRejectedByPolicy(); + await expect( + anonDb.m.create({ data: { owner: { connect: { x_y: { x: '1', y: '1' } } } } }) + ).toBeRejectedByPolicy(); + await expect( + anonDb.n.create({ data: { owner: { connect: { x_y: { x: '1', y: '2' } } } } }) + ).toBeRejectedByPolicy(); + await expect( + anonDb.n.create({ data: { owner: { connect: { x_y: { x: '1', y: '1' } } } } }) + ).toBeRejectedByPolicy(); + + const db = withPolicy({ x: '1', y: '1' }); + + await expect(db.m.create({ data: { owner: { connect: { x_y: { x: '1', y: '2' } } } } })).toBeRejectedByPolicy(); + await expect(db.m.create({ data: { owner: { connect: { x_y: { x: '1', y: '1' } } } } })).toResolveTruthy(); + await expect(db.n.create({ data: { owner: { connect: { x_y: { x: '1', y: '2' } } } } })).toBeRejectedByPolicy(); + await expect(db.n.create({ data: { owner: { connect: { x_y: { x: '1', y: '1' } } } } })).toResolveTruthy(); + await expect(db.p.create({ data: { owner: { connect: { x_y: { x: '1', y: '1' } } } } })).toBeRejectedByPolicy(); + await expect(db.p.create({ data: { owner: { connect: { x_y: { x: '1', y: '2' } } } } })).toResolveTruthy(); + + await expect( + withPolicy(undefined).q.create({ data: { owner: { connect: { x_y: { x: '1', y: '1' } } } } }) + ).toBeRejectedByPolicy(); + await expect(db.q.create({ data: { owner: { connect: { x_y: { x: '1', y: '2' } } } } })).toResolveTruthy(); + }); }); From 803399203554053b7f1cfe50ba56dfede3a19af1 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 28 Mar 2023 08:32:32 +0800 Subject: [PATCH 2/2] fix tests --- .../schema/src/plugins/access-policy/expression-writer.ts | 2 +- packages/schema/tests/generator/expression-writer.test.ts | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/schema/src/plugins/access-policy/expression-writer.ts b/packages/schema/src/plugins/access-policy/expression-writer.ts index 5b3e85850..03fe11fb1 100644 --- a/packages/schema/src/plugins/access-policy/expression-writer.ts +++ b/packages/schema/src/plugins/access-policy/expression-writer.ts @@ -341,7 +341,7 @@ export class ExpressionWriter { // wrap a 'not' this.writer.write('not: '); this.block(() => { - this.writer.write(`${this.mapOperator(operator)}: `); + this.writer.write(`${this.mapOperator('==')}: `); writeOperand(); }); } else { diff --git a/packages/schema/tests/generator/expression-writer.test.ts b/packages/schema/tests/generator/expression-writer.test.ts index 6f50d50d3..157d4916e 100644 --- a/packages/schema/tests/generator/expression-writer.test.ts +++ b/packages/schema/tests/generator/expression-writer.test.ts @@ -914,7 +914,7 @@ describe('Expression Writer Tests', () => { (model) => model.attributes[0].args[1].value, ` { - OR: [{ m: { equals: null } }, { m: { s: { equals: null } } }] + OR: [{ m: { is: null } }, { m: { s: { equals: null } } }] } ` ); @@ -937,7 +937,7 @@ describe('Expression Writer Tests', () => { (model) => model.attributes[0].args[1].value, ` { - OR: [{ m: { not: { equals: null } } }, { m: { s: { not: { equals: null } } } }] + OR: [{ m: { isNot: null } }, { m: { s: { not: { equals: null } } } }] } ` );