diff --git a/packages/schema/src/language-server/validator/attribute-application-validator.ts b/packages/schema/src/language-server/validator/attribute-application-validator.ts index a820620d1..e25563cd2 100644 --- a/packages/schema/src/language-server/validator/attribute-application-validator.ts +++ b/packages/schema/src/language-server/validator/attribute-application-validator.ts @@ -16,7 +16,7 @@ import { isReferenceExpr, } from '@zenstackhq/language/ast'; import { isFutureExpr, resolved } from '@zenstackhq/sdk'; -import { ValidationAcceptor, streamAllContents } from 'langium'; +import { ValidationAcceptor, streamAst } from 'langium'; import pluralize from 'pluralize'; import { AstValidator } from '../types'; import { getStringLiteral, mapBuiltinTypeToExpressionType, typeAssignable } from './utils'; @@ -134,7 +134,7 @@ export default class AttributeApplicationValidator implements AstValidator isFutureExpr(node))) { + if (streamAst(expr).some((node) => isFutureExpr(node))) { accept('error', `"future()" is not allowed in field-level policy rules`, { node: expr }); } } diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts index 3d6343a9e..21548c89f 100644 --- a/packages/schema/src/language-server/validator/expression-validator.ts +++ b/packages/schema/src/language-server/validator/expression-validator.ts @@ -2,7 +2,6 @@ import { BinaryExpr, Expression, ExpressionType, - isBinaryExpr, isDataModel, isEnum, isNullExpr, @@ -10,7 +9,7 @@ import { } from '@zenstackhq/language/ast'; import { isDataModelFieldReference } from '@zenstackhq/sdk'; import { ValidationAcceptor } from 'langium'; -import { isAuthInvocation } from '../../utils/ast-utils'; +import { isAuthInvocation, isCollectionPredicate } from '../../utils/ast-utils'; import { AstValidator } from '../types'; /** @@ -23,7 +22,7 @@ export default class ExpressionValidator implements AstValidator { if (isAuthInvocation(expr)) { // check was done at link time accept('error', 'auth() cannot be resolved because no "User" model is defined', { node: expr }); - } else if (this.isCollectionPredicate(expr)) { + } else if (isCollectionPredicate(expr)) { accept('error', 'collection predicate can only be used on an array of model type', { node: expr }); } else { accept('error', 'expression cannot be resolved', { @@ -142,8 +141,4 @@ export default class ExpressionValidator implements AstValidator { } } } - - private isCollectionPredicate(expr: Expression) { - return isBinaryExpr(expr) && ['?', '!', '^'].includes(expr.operator); - } } diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index b6ac99576..d76954b43 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -5,7 +5,6 @@ import { DataModelFieldAttribute, Enum, Expression, - MemberAccessExpr, Model, isBinaryExpr, isDataModel, @@ -49,12 +48,12 @@ import { resolved, saveProject, } from '@zenstackhq/sdk'; -import { streamAllContents } from 'langium'; +import { streamAllContents, streamAst, streamContents } from 'langium'; import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; import { FunctionDeclaration, SourceFile, VariableDeclarationKind, WriterFunction } from 'ts-morph'; import { name } from '.'; -import { getIdFields, isAuthInvocation } from '../../utils/ast-utils'; +import { getIdFields, isAuthInvocation, isCollectionPredicate } from '../../utils/ast-utils'; import { TypeScriptExpressionTransformer, TypeScriptExpressionTransformerError, @@ -237,7 +236,7 @@ export default class PolicyGenerator { } private hasFutureReference(expr: Expression) { - for (const node of this.allNodes(expr)) { + for (const node of streamAst(expr)) { if (isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)) { return true; } @@ -434,7 +433,7 @@ export default class PolicyGenerator { private canCheckCreateBasedOnInput(model: DataModel, allows: Expression[], denies: Expression[]) { return [...allows, ...denies].every((rule) => { - return [...this.allNodes(rule)].every((expr) => { + return streamAst(rule).every((expr) => { if (isThisExpr(expr)) { return false; } @@ -487,6 +486,8 @@ export default class PolicyGenerator { }); }; + // visit a reference or member access expression to build a + // selection path const visit = (node: Expression): string[] | undefined => { if (isReferenceExpr(node)) { const target = resolved(node.target); @@ -509,35 +510,50 @@ export default class PolicyGenerator { return undefined; }; - for (const rule of [...allows, ...denies]) { - for (const expr of [...this.allNodes(rule)].filter((node): node is Expression => isExpression(node))) { - if (isThisExpr(expr) && !isMemberAccessExpr(expr.$container)) { - // a standalone `this` expression, include all id fields - const model = expr.$resolvedType?.decl as DataModel; - const idFields = getIdFields(model); - idFields.forEach((field) => addPath([field.name])); - continue; - } - - // only care about member access and reference expressions - if (!isMemberAccessExpr(expr) && !isReferenceExpr(expr)) { - continue; - } - - if (expr.$container.$type === MemberAccessExpr) { - // only visit top-level member access - continue; - } + // collect selection paths from the given expression + const collectReferencePaths = (expr: Expression): string[][] => { + if (isThisExpr(expr) && !isMemberAccessExpr(expr.$container)) { + // a standalone `this` expression, include all id fields + const model = expr.$resolvedType?.decl as DataModel; + const idFields = getIdFields(model); + return idFields.map((field) => [field.name]); + } + if (isMemberAccessExpr(expr) || isReferenceExpr(expr)) { const path = visit(expr); if (path) { if (isDataModel(expr.$resolvedType?.decl)) { - // member selection ended at a data model field, include its 'id' - path.push('id'); + // member selection ended at a data model field, include its id fields + const idFields = getIdFields(expr.$resolvedType?.decl as DataModel); + return idFields.map((field) => [...path, field.name]); + } else { + return [path]; } - addPath(path); + } else { + return []; } + } else if (isCollectionPredicate(expr)) { + const path = visit(expr.left); + if (path) { + // recurse into RHS + const rhs = collectReferencePaths(expr.right); + // combine path of LHS and RHS + return rhs.map((r) => [...path, ...r]); + } else { + return []; + } + } else { + // recurse + const children = streamContents(expr) + .filter((child): child is Expression => isExpression(child)) + .toArray(); + return children.flatMap((child) => collectReferencePaths(child)); } + }; + + for (const rule of [...allows, ...denies]) { + const paths = collectReferencePaths(rule); + paths.forEach((p) => addPath(p)); } return Object.keys(result).length === 0 ? undefined : result; @@ -556,7 +572,7 @@ export default class PolicyGenerator { this.generateNormalizedAuthRef(model, allows, denies, statements); const hasFieldAccess = [...denies, ...allows].some((rule) => - [...this.allNodes(rule)].some( + streamAst(rule).some( (child) => // this.??? isThisExpr(child) || @@ -724,7 +740,7 @@ export default class PolicyGenerator { ) { // check if any allow or deny rule contains 'auth()' invocation const hasAuthRef = [...allows, ...denies].some((rule) => - [...this.allNodes(rule)].some((child) => isAuthInvocation(child)) + streamAst(rule).some((child) => isAuthInvocation(child)) ); if (hasAuthRef) { @@ -747,9 +763,4 @@ export default class PolicyGenerator { ); } } - - private *allNodes(expr: Expression) { - yield expr; - yield* streamAllContents(expr); - } } diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index b2c0771be..cd6853daa 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -1,8 +1,10 @@ import { + BinaryExpr, DataModel, DataModelField, Expression, isArrayExpr, + isBinaryExpr, isDataModel, isDataModelField, isInvocationExpr, @@ -150,3 +152,7 @@ export function getAllDeclarationsFromImports(documents: LangiumDocuments, model const imports = resolveTransitiveImports(documents, model); return model.declarations.concat(...imports.map((imp) => imp.declarations)); } + +export function isCollectionPredicate(expr: Expression): expr is BinaryExpr { + return isBinaryExpr(expr) && ['?', '!', '^'].includes(expr.operator); +} diff --git a/packages/schema/src/utils/typescript-expression-transformer.ts b/packages/schema/src/utils/typescript-expression-transformer.ts index 17be22406..74bb9d766 100644 --- a/packages/schema/src/utils/typescript-expression-transformer.ts +++ b/packages/schema/src/utils/typescript-expression-transformer.ts @@ -17,6 +17,7 @@ import { UnaryExpr, } from '@zenstackhq/language/ast'; import { ExpressionContext, getLiteral, isFromStdlib, isFutureExpr } from '@zenstackhq/sdk'; +import { match, P } from 'ts-pattern'; import { getIdFields } from './ast-utils'; export class TypeScriptExpressionTransformerError extends Error { @@ -53,7 +54,7 @@ export class TypeScriptExpressionTransformer { * * @param isPostGuard indicates if we're writing for post-update conditions */ - constructor(private readonly options?: Options) {} + constructor(private readonly options: Options) {} /** * Transforms the given expression to a TypeScript expression. @@ -302,33 +303,57 @@ export class TypeScriptExpressionTransformer { } private binary(expr: BinaryExpr, normalizeUndefined: boolean): string { - if (expr.operator === 'in') { - return `(${this.transform(expr.right, false)}?.includes(${this.transform( - expr.left, - normalizeUndefined - )}) ?? false)`; - } else if ( - (expr.operator === '==' || expr.operator === '!=') && - (isThisExpr(expr.left) || isThisExpr(expr.right)) - ) { - // map equality comparison with `this` to id comparison - const _this = isThisExpr(expr.left) ? expr.left : expr.right; - const model = _this.$resolvedType?.decl as DataModel; - const idFields = getIdFields(model); - if (!idFields || idFields.length === 0) { - throw new TypeScriptExpressionTransformerError(`model "${model.name}" does not have an id field`); - } - let result = `allFieldsEqual(${this.transform(expr.left, false)}, + const _default = `(${this.transform(expr.left, normalizeUndefined)} ${expr.operator} ${this.transform( + expr.right, + normalizeUndefined + )})`; + + return match(expr.operator) + .with( + 'in', + () => + `(${this.transform(expr.right, false)}?.includes(${this.transform( + expr.left, + normalizeUndefined + )}) ?? false)` + ) + .with(P.union('==', '!='), () => { + if (isThisExpr(expr.left) || isThisExpr(expr.right)) { + // map equality comparison with `this` to id comparison + const _this = isThisExpr(expr.left) ? expr.left : expr.right; + const model = _this.$resolvedType?.decl as DataModel; + const idFields = getIdFields(model); + if (!idFields || idFields.length === 0) { + throw new TypeScriptExpressionTransformerError( + `model "${model.name}" does not have an id field` + ); + } + let result = `allFieldsEqual(${this.transform(expr.left, false)}, ${this.transform(expr.right, false)}, [${idFields.map((f) => "'" + f.name + "'").join(', ')}])`; - if (expr.operator === '!=') { - result = `!${result}`; - } - return result; - } else { - return `(${this.transform(expr.left, normalizeUndefined)} ${expr.operator} ${this.transform( - expr.right, - normalizeUndefined - )})`; - } + if (expr.operator === '!=') { + result = `!${result}`; + } + return result; + } else { + return _default; + } + }) + .with(P.union('?', '!', '^'), (op) => this.collectionPredicate(expr, op, normalizeUndefined)) + .otherwise(() => _default); + } + + private collectionPredicate(expr: BinaryExpr, operator: '?' | '!' | '^', normalizeUndefined: boolean) { + const operand = this.transform(expr.left, normalizeUndefined); + const innerTransformer = new TypeScriptExpressionTransformer({ + ...this.options, + fieldReferenceContext: '_item', + }); + const predicate = innerTransformer.transform(expr.right, normalizeUndefined); + + return match(operator) + .with('?', () => `!!((${operand})?.some((_item: any) => ${predicate}))`) + .with('!', () => `!!((${operand})?.every((_item: any) => ${predicate}))`) + .with('^', () => `!((${operand})?.some((_item: any) => ${predicate}))`) + .exhaustive(); } } diff --git a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts index 209876f25..e43fd370d 100644 --- a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts +++ b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts @@ -746,4 +746,99 @@ describe('With Policy: field-level policy', () => { r = await withPolicy({ id: 2 }).user.findFirst(); expect(r.username).toBeUndefined(); }); + + it('collection predicate', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + foos Foo[] + a Int @allow('read', foos?[x > 0 && bars![y > 0]]) + b Int @allow('read', foos^[x == 1]) + + @@allow('all', true) + } + + model Foo { + id Int @id @default(autoincrement()) + x Int + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + bars Bar[] + + @@allow('all', true) + } + + model Bar { + id Int @id @default(autoincrement()) + y Int + foo Foo @relation(fields: [fooId], references: [id]) + fooId Int + + @@allow('all', true) + } + ` + ); + + const db = withPolicy(); + + await prisma.user.create({ + data: { + id: 1, + a: 1, + b: 2, + foos: { + create: [ + { x: 0, bars: { create: [{ y: 1 }] } }, + { x: 1, bars: { create: [{ y: 0 }, { y: 1 }] } }, + ], + }, + }, + }); + + let r = await db.user.findUnique({ where: { id: 1 } }); + expect(r.a).toBeUndefined(); + expect(r.b).toBeUndefined(); + + await prisma.user.create({ + data: { + id: 2, + a: 1, + b: 2, + foos: { + create: [{ x: 2, bars: { create: [{ y: 0 }, { y: 1 }] } }], + }, + }, + }); + r = await db.user.findUnique({ where: { id: 2 } }); + expect(r.a).toBeUndefined(); + expect(r.b).toBe(2); + + await prisma.user.create({ + data: { + id: 3, + a: 1, + b: 2, + foos: { + create: [{ x: 2 }], + }, + }, + }); + r = await db.user.findUnique({ where: { id: 3 } }); + expect(r.a).toBe(1); + + await prisma.user.create({ + data: { + id: 4, + a: 1, + b: 2, + foos: { + create: [{ x: 2, bars: { create: [{ y: 1 }, { y: 2 }] } }], + }, + }, + }); + r = await db.user.findUnique({ where: { id: 4 } }); + expect(r.a).toBe(1); + expect(r.b).toBe(2); + }); }); diff --git a/tests/integration/tests/regression/issue-703.test.ts b/tests/integration/tests/regression/issue-703.test.ts new file mode 100644 index 000000000..203100496 --- /dev/null +++ b/tests/integration/tests/regression/issue-703.test.ts @@ -0,0 +1,26 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Regression: issue 703', () => { + it('regression', async () => { + await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + name String? + admin Boolean @default(false) + + companiesWorkedFor Company[] + + username String @unique @allow("all", auth() == this) @allow('read', companiesWorkedFor?[owner == auth()]) @allow("all", auth().admin) + } + + model Company { + id Int @id @default(autoincrement()) + name String? + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + } + ` + ); + }); +});