diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts index 9a17414e8..18c826158 100644 --- a/packages/schema/src/language-server/validator/expression-validator.ts +++ b/packages/schema/src/language-server/validator/expression-validator.ts @@ -1,6 +1,6 @@ -import { BinaryExpr, Expression, isArrayExpr, isBinaryExpr, isEnum, isLiteralExpr } from '@zenstackhq/language/ast'; +import { BinaryExpr, Expression, isBinaryExpr, isEnum } from '@zenstackhq/language/ast'; import { ValidationAcceptor } from 'langium'; -import { getDataModelFieldReference, isAuthInvocation, isEnumFieldReference } from '../../utils/ast-utils'; +import { isAuthInvocation } from '../../utils/ast-utils'; import { AstValidator } from '../types'; /** @@ -37,21 +37,12 @@ export default class ExpressionValidator implements AstValidator { private validateBinaryExpr(expr: BinaryExpr, accept: ValidationAcceptor) { switch (expr.operator) { case 'in': { - if (!getDataModelFieldReference(expr.left)) { - accept('error', 'left operand of "in" must be a field reference', { node: expr.left }); - } - if (typeof expr.left.$resolvedType?.decl !== 'string' && !isEnum(expr.left.$resolvedType?.decl)) { accept('error', 'left operand of "in" must be of scalar type', { node: expr.left }); } - if ( - !( - isArrayExpr(expr.right) && - expr.right.items.every((item) => isLiteralExpr(item) || isEnumFieldReference(item)) - ) - ) { - accept('error', 'right operand of "in" must be an array of literals or enum values', { + if (!expr.right.$resolvedType?.array) { + accept('error', 'right operand of "in" must be an array', { node: expr.right, }); } diff --git a/packages/schema/src/plugins/access-policy/expression-writer.ts b/packages/schema/src/plugins/access-policy/expression-writer.ts index 03fe11fb1..238eed6c1 100644 --- a/packages/schema/src/plugins/access-policy/expression-writer.ts +++ b/packages/schema/src/plugins/access-policy/expression-writer.ts @@ -153,14 +153,35 @@ export class ExpressionWriter { } private writeIn(expr: BinaryExpr) { + const leftIsFieldAccess = this.isFieldAccess(expr.left); + const rightIsFieldAccess = this.isFieldAccess(expr.right); + this.block(() => { - this.writeFieldCondition( - expr.left, - () => { - this.plain(expr.right); - }, - 'in' - ); + if (!leftIsFieldAccess && !rightIsFieldAccess) { + // 'in' without referencing fields + this.guard(() => this.plain(expr)); + } else if (leftIsFieldAccess && !rightIsFieldAccess) { + // 'in' with left referencing a field, right is an array literal + this.writeFieldCondition( + expr.left, + () => { + this.plain(expr.right); + }, + 'in' + ); + } else if (!leftIsFieldAccess && rightIsFieldAccess) { + // 'in' with right referencing an array field, left is a literal + // transform it into a 'has' filter + this.writeFieldCondition( + expr.right, + () => { + this.plain(expr.left); + }, + 'has' + ); + } else { + throw new PluginError('"in" operator cannot be used with field references on both sides'); + } }); } @@ -520,6 +541,12 @@ export class ExpressionWriter { } if (FILTER_OPERATOR_FUNCTIONS.includes(funcDecl.name)) { + if (!expr.args.some((arg) => this.isFieldAccess(arg.value))) { + // filter functions without referencing fields + this.block(() => this.guard(() => this.plain(expr))); + return; + } + let valueArg = expr.args[1]?.value; // isEmpty function is zero arity, it's mapped to a boolean literal diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index 7d04d1691..62079cef0 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -52,6 +52,7 @@ export default class PolicyGenerator { const project = createProject(); const sf = project.createSourceFile(path.join(output, 'policy.ts'), undefined, { overwrite: true }); + sf.addStatements('/* eslint-disable */'); sf.addImportDeclaration({ namedImports: [{ name: 'type QueryContext' }, { name: 'hasAllFields' }], @@ -361,7 +362,7 @@ export default class PolicyGenerator { func.addStatements( `const user = hasAllFields(context.user, [${userIdFields .map((f) => "'" + f.name + "'") - .join(', ')}]) ? context.user : null;` + .join(', ')}]) ? context.user as any : null;` ); } diff --git a/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts b/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts index cb6dfba5e..98dde9004 100644 --- a/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts +++ b/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts @@ -12,7 +12,8 @@ import { ThisExpr, UnaryExpr, } from '@zenstackhq/language/ast'; -import { PluginError } from '@zenstackhq/sdk'; +import { getLiteral, PluginError } from '@zenstackhq/sdk'; +import { FILTER_OPERATOR_FUNCTIONS } from '../../language-server/constants'; import { isAuthInvocation } from '../../utils/ast-utils'; import { isFutureExpr } from './utils'; @@ -28,17 +29,17 @@ export default class TypeScriptExpressionTransformer { constructor(private readonly isPostGuard = false) {} /** - * - * @param expr + * Transforms the given expression to a TypeScript expression. + * @param normalizeUndefined if undefined values should be normalized to null * @returns */ - transform(expr: Expression): string { + transform(expr: Expression, normalizeUndefined = true): string { switch (expr.$type) { case LiteralExpr: return this.literal(expr as LiteralExpr); case ArrayExpr: - return this.array(expr as ArrayExpr); + return this.array(expr as ArrayExpr, normalizeUndefined); case NullExpr: return this.null(); @@ -50,16 +51,16 @@ export default class TypeScriptExpressionTransformer { return this.reference(expr as ReferenceExpr); case InvocationExpr: - return this.invocation(expr as InvocationExpr); + return this.invocation(expr as InvocationExpr, normalizeUndefined); case MemberAccessExpr: - return this.memberAccess(expr as MemberAccessExpr); + return this.memberAccess(expr as MemberAccessExpr, normalizeUndefined); case UnaryExpr: - return this.unary(expr as UnaryExpr); + return this.unary(expr as UnaryExpr, normalizeUndefined); case BinaryExpr: - return this.binary(expr as BinaryExpr); + return this.binary(expr as BinaryExpr, normalizeUndefined); default: throw new PluginError(`Unsupported expression type: ${expr.$type}`); @@ -72,7 +73,7 @@ export default class TypeScriptExpressionTransformer { return 'id'; } - private memberAccess(expr: MemberAccessExpr) { + private memberAccess(expr: MemberAccessExpr, normalizeUndefined: boolean) { if (!expr.member.ref) { throw new PluginError(`Unresolved MemberAccessExpr`); } @@ -85,14 +86,71 @@ export default class TypeScriptExpressionTransformer { } return expr.member.ref.name; } else { - // normalize field access to null instead of undefined to avoid accidentally use undefined in filter - return `(${this.transform(expr.operand)}?.${expr.member.ref.name} ?? null)`; + if (normalizeUndefined) { + // normalize field access to null instead of undefined to avoid accidentally use undefined in filter + return `(${this.transform(expr.operand, normalizeUndefined)}?.${expr.member.ref.name} ?? null)`; + } else { + return `${this.transform(expr.operand, normalizeUndefined)}?.${expr.member.ref.name}`; + } } } - private invocation(expr: InvocationExpr) { + private invocation(expr: InvocationExpr, normalizeUndefined: boolean) { + if (!expr.function.ref) { + throw new PluginError(`Unresolved InvocationExpr`); + } + if (isAuthInvocation(expr)) { return 'user'; + } else if (FILTER_OPERATOR_FUNCTIONS.includes(expr.function.ref.name)) { + // arguments are already type-checked + + const arg0 = this.transform(expr.args[0].value, false); + let result: string; + switch (expr.function.ref.name) { + case 'contains': { + const caseInsensitive = getLiteral(expr.args[2]?.value) === true; + if (caseInsensitive) { + result = `${arg0}?.toLowerCase().includes(${this.transform( + expr.args[1].value, + normalizeUndefined + )}?.toLowerCase())`; + } else { + result = `${arg0}?.includes(${this.transform(expr.args[1].value, normalizeUndefined)})`; + } + break; + } + case 'search': + throw new PluginError('"search" function must be used against a field'); + case 'startsWith': + result = `${arg0}?.startsWith(${this.transform(expr.args[1].value, normalizeUndefined)})`; + break; + case 'endsWith': + result = `${arg0}?.endsWith(${this.transform(expr.args[1].value, normalizeUndefined)})`; + break; + case 'has': + result = `${arg0}?.includes(${this.transform(expr.args[1].value, normalizeUndefined)})`; + break; + case 'hasEvery': + result = `${this.transform( + expr.args[1].value, + normalizeUndefined + )}?.every((item) => ${arg0}?.includes(item))`; + break; + case 'hasSome': + result = `${this.transform( + expr.args[1].value, + normalizeUndefined + )}?.some((item) => ${arg0}?.includes(item))`; + break; + case 'isEmpty': + result = `${arg0}?.length === 0`; + break; + default: + throw new PluginError(`Function invocation is not supported: ${expr.function.ref?.name}`); + } + + return `(${result} ?? false)`; } else { throw new PluginError(`Function invocation is not supported: ${expr.function.ref?.name}`); } @@ -121,8 +179,8 @@ export default class TypeScriptExpressionTransformer { return 'null'; } - private array(expr: ArrayExpr) { - return `[${expr.items.map((item) => this.transform(item)).join(', ')}]`; + private array(expr: ArrayExpr, normalizeUndefined: boolean) { + return `[${expr.items.map((item) => this.transform(item, normalizeUndefined)).join(', ')}]`; } private literal(expr: LiteralExpr) { @@ -133,11 +191,18 @@ export default class TypeScriptExpressionTransformer { } } - private unary(expr: UnaryExpr): string { - return `(${expr.operator} ${this.transform(expr.operand)})`; + private unary(expr: UnaryExpr, normalizeUndefined: boolean): string { + return `(${expr.operator} ${this.transform(expr.operand, normalizeUndefined)})`; } - private binary(expr: BinaryExpr): string { - return `(${this.transform(expr.left)} ${expr.operator} ${this.transform(expr.right)})`; + private binary(expr: BinaryExpr, normalizeUndefined: boolean): string { + if (expr.operator === 'in') { + return `(${this.transform(expr.right, false)}?.includes(${this.transform( + expr.left, + normalizeUndefined + )}) ?? false)`; + } else { + return `(${this.transform(expr.left)} ${expr.operator} ${this.transform(expr.right, normalizeUndefined)})`; + } } } diff --git a/packages/schema/src/plugins/model-meta/index.ts b/packages/schema/src/plugins/model-meta/index.ts index 271dcde64..61a539e78 100644 --- a/packages/schema/src/plugins/model-meta/index.ts +++ b/packages/schema/src/plugins/model-meta/index.ts @@ -43,6 +43,7 @@ export default async function run(model: Model, options: PluginOptions) { } const sf = project.createSourceFile(path.join(output, 'model-meta.ts'), undefined, { overwrite: true }); + sf.addStatements('/* eslint-disable */'); sf.addVariableStatement({ declarationKind: VariableDeclarationKind.Const, declarations: [{ name: 'metadata', initializer: (writer) => generateModelMetadata(dataModels, writer) }], diff --git a/packages/schema/src/res/stdlib.zmodel b/packages/schema/src/res/stdlib.zmodel index f76ef7251..8787f6e57 100644 --- a/packages/schema/src/res/stdlib.zmodel +++ b/packages/schema/src/res/stdlib.zmodel @@ -99,9 +99,10 @@ function future(): Any { } /* - * If the field value contains the search string + * If the field value contains the search string. By default, the search is case-sensitive, + * but you can override the behavior with the "caseInSensitive" argument. */ -function contains(field: String, search: String, caseSensitive: Boolean?): Boolean { +function contains(field: String, search: String, caseInSensitive: Boolean?): Boolean { } /* diff --git a/packages/schema/tests/generator/expression-writer.test.ts b/packages/schema/tests/generator/expression-writer.test.ts index e3296a720..178abe78e 100644 --- a/packages/schema/tests/generator/expression-writer.test.ts +++ b/packages/schema/tests/generator/expression-writer.test.ts @@ -943,7 +943,7 @@ describe('Expression Writer Tests', () => { ); }); - it('filter operators', async () => { + it('filter operators field access', async () => { await check( ` enum Role { @@ -1134,6 +1134,153 @@ describe('Expression Writer Tests', () => { }); }); +it('filter operators non-field access', async () => { + const userInit = `{ id: 'user1', email: 'test@zenstack.dev', roles: [Role.ADMIN] }`; + const prelude = ` + enum Role { + USER + ADMIN + } + + model User { + id String @id + email String + roles Role[] + } + `; + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', ADMIN in auth().roles) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:(user?.roles?.includes(Role.ADMIN)??false)}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + roles Role[] + @@allow('all', ADMIN in roles) + } + `, + (model) => model.attributes[0].args[1].value, + `{roles:{has:Role.ADMIN}}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', contains(auth().email, 'test')) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:(user?.email?.includes('test')??false)}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', contains(auth().email, 'test', true)) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:(user?.email?.toLowerCase().includes('test'?.toLowerCase())??false)}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', startsWith(auth().email, 'test')) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:(user?.email?.startsWith('test')??false)}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', endsWith(auth().email, 'test')) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:(user?.email?.endsWith('test')??false)}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', has(auth().roles, ADMIN)) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:(user?.roles?.includes(Role.ADMIN)??false)}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', hasEvery(auth().roles, [ADMIN, USER])) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:([Role.ADMIN,Role.USER]?.every((item)=>user?.roles?.includes(item))??false)}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', hasSome(auth().roles, [USER, ADMIN])) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:([Role.USER,Role.ADMIN]?.some((item)=>user?.roles?.includes(item))??false)}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', isEmpty(auth().roles)) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:(user?.roles?.length===0??false)}`, + userInit + ); +}); + async function check(schema: string, getExpr: (model: DataModel) => Expression, expected: string, userInit?: string) { if (!schema.includes('datasource ')) { schema = @@ -1155,12 +1302,6 @@ async function check(schema: string, getExpr: (model: DataModel) => Expression, overwrite: true, }); - // inject user variable - sf.addVariableStatement({ - declarationKind: VariableDeclarationKind.Const, - declarations: [{ name: 'user', initializer: userInit ?? '{ id: "user1" }' }], - }); - // inject enums model.declarations .filter((d) => isEnum(d)) @@ -1180,6 +1321,12 @@ async function check(schema: string, getExpr: (model: DataModel) => Expression, }); }); + // inject user variable + sf.addVariableStatement({ + declarationKind: VariableDeclarationKind.Const, + declarations: [{ name: 'user', initializer: userInit ?? '{ id: "user1" }' }], + }); + sf.addVariableStatement({ declarationKind: VariableDeclarationKind.Const, declarations: [ @@ -1197,7 +1344,6 @@ async function check(schema: string, getExpr: (model: DataModel) => Expression, for (const d of project.getPreEmitDiagnostics()) { console.warn(`${d.getLineNumber()}: ${d.getMessageText()}`); } - console.log(`Generated source: ${sourcePath}`); throw new Error('Compilation errors occurred'); } diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index d541d43a9..faf88eb9f 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -461,17 +461,6 @@ describe('Attribute tests', () => { `) ).toContain('argument is not assignable to parameter'); - expect( - await loadModelWithError(` - ${prelude} - model M { - id String @id - i Int[] - @@allow('all', 1 in i) - } - `) - ).toContain('left operand of "in" must be a field reference'); - expect( await loadModelWithError(` ${prelude} @@ -481,7 +470,7 @@ describe('Attribute tests', () => { @@allow('all', i in 1) } `) - ).toContain('right operand of "in" must be an array of literals or enum values'); + ).toContain('right operand of "in" must be an array'); expect( await loadModelWithError(` diff --git a/tests/integration/tests/e2e/filter-function-coverage.test.ts b/tests/integration/tests/e2e/filter-function-coverage.test.ts index daedbbef3..a7d6088c4 100644 --- a/tests/integration/tests/e2e/filter-function-coverage.test.ts +++ b/tests/integration/tests/e2e/filter-function-coverage.test.ts @@ -1,7 +1,7 @@ import { loadSchema } from '@zenstackhq/testtools'; describe('Filter Function Coverage Tests', () => { - it('contains case-sensitive', async () => { + it('contains case-sensitive field', async () => { const { withPresets } = await loadSchema( ` model Foo { @@ -16,7 +16,27 @@ describe('Filter Function Coverage Tests', () => { await expect(withPresets().foo.create({ data: { string: 'bac' } })).toResolveTruthy(); }); - it('startsWith', async () => { + it('contains case-sensitive non-field', async () => { + const { withPresets } = await loadSchema( + ` + model User { + id String @id + name String + } + + model Foo { + id String @id @default(cuid()) + @@allow('all', contains(auth().name, 'a')) + } + ` + ); + + await expect(withPresets().foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(withPresets({ id: 'user1', name: 'bcd' }).foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(withPresets({ id: 'user1', name: 'bac' }).foo.create({ data: {} })).toResolveTruthy(); + }); + + it('startsWith field', async () => { const { withPresets } = await loadSchema( ` model Foo { @@ -31,7 +51,27 @@ describe('Filter Function Coverage Tests', () => { await expect(withPresets().foo.create({ data: { string: 'abc' } })).toResolveTruthy(); }); - it('endsWith', async () => { + it('startsWith non-field', async () => { + const { withPresets } = await loadSchema( + ` + model User { + id String @id + name String + } + + model Foo { + id String @id @default(cuid()) + @@allow('all', startsWith(auth().name, 'a')) + } + ` + ); + + await expect(withPresets().foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(withPresets({ id: 'user1', name: 'bac' }).foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(withPresets({ id: 'user1', name: 'abc' }).foo.create({ data: {} })).toResolveTruthy(); + }); + + it('endsWith field', async () => { const { withPresets } = await loadSchema( ` model Foo { @@ -46,7 +86,27 @@ describe('Filter Function Coverage Tests', () => { await expect(withPresets().foo.create({ data: { string: 'bca' } })).toResolveTruthy(); }); - it('in', async () => { + it('endsWith non-field', async () => { + const { withPresets } = await loadSchema( + ` + model User { + id String @id + name String + } + + model Foo { + id String @id @default(cuid()) + @@allow('all', endsWith(auth().name, 'a')) + } + ` + ); + + await expect(withPresets().foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(withPresets({ id: 'user1', name: 'bac' }).foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(withPresets({ id: 'user1', name: 'bca' }).foo.create({ data: {} })).toResolveTruthy(); + }); + + it('in left field', async () => { const { withPresets } = await loadSchema( ` model Foo { @@ -60,4 +120,24 @@ describe('Filter Function Coverage Tests', () => { await expect(withPresets().foo.create({ data: { string: 'c' } })).toBeRejectedByPolicy(); await expect(withPresets().foo.create({ data: { string: 'b' } })).toResolveTruthy(); }); + + it('in non-field', async () => { + const { withPresets } = await loadSchema( + ` + model User { + id String @id + name String + } + + model Foo { + id String @id @default(cuid()) + @@allow('all', auth().name in ['abc', 'bcd']) + } + ` + ); + + await expect(withPresets().foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(withPresets({ id: 'user1', name: 'abd' }).foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(withPresets({ id: 'user1', name: 'abc' }).foo.create({ data: {} })).toResolveTruthy(); + }); });