From 2a3219a1c0a6cf78cce34b9aef4e08d90564cb0d Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Wed, 16 Aug 2023 18:29:49 +0800 Subject: [PATCH 1/6] WIP: field-level-access-control --- README.md | 8 +- .../src/enhancements/policy/handler.ts | 4 +- .../src/enhancements/policy/policy-utils.ts | 51 ++++++- packages/runtime/src/enhancements/types.ts | 1 + .../access-policy/policy-guard-generator.ts | 132 ++++++++++++++++-- packages/schema/src/res/stdlib.zmodel | 10 ++ .../with-policy/field-level-policy.test.ts | 45 ++++++ 7 files changed, 233 insertions(+), 18 deletions(-) create mode 100644 tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts diff --git a/README.md b/README.md index 41cb26e15..4860247a9 100644 --- a/README.md +++ b/README.md @@ -62,14 +62,14 @@ The `zenstack` CLI transpiles the ZModel into a standard Prisma schema, which yo At runtime, transparent proxies are created around Prisma clients for intercepting queries and mutations to enforce access policies. ```ts -import { withPolicy } from '@zenstackhq/runtime'; +import { enhance } from '@zenstackhq/runtime'; // a regular Prisma client const prisma = new PrismaClient(); async function getPosts(userId: string) { // create an enhanced Prisma client that has access control enabled - const enhanced = withPolicy(prisma, { user: userId }); + const enhanced = enhance(prisma, { user: userId }); // only posts that're visible to the user will be returned return enhanced.post.findMany(); @@ -84,14 +84,14 @@ Server adapter packages help you wrap an access-control-enabled Prisma client in // pages/api/model/[...path].ts import { requestHandler } from '@zenstackhq/next'; -import { withPolicy } from '@zenstackhq/runtime'; +import { enhance } from '@zenstackhq/runtime'; import { getSessionUser } from '@lib/auth'; import { prisma } from '@lib/db'; // Mount Prisma-style APIs: "/api/model/post/findMany", "/api/model/post/create", etc. // Can be configured to provide standard RESTful APIs (using JSON:API) instead. export default requestHandler({ - getPrisma: (req, res) => withPolicy(prisma, { user: getSessionUser(req, res) }), + getPrisma: (req, res) => enhance(prisma, { user: getSessionUser(req, res) }), }); ``` diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 1a572315f..235282932 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -69,11 +69,13 @@ export class PolicyProxyHandler implements Pr return null; } + this.utils.injectReadCheckSelect(args); + if (this.shouldLogQuery) { this.logger.info(`[policy] \`findUnique\` ${this.model}:\n${formatObject(args)}`); } const result = await this.modelClient.findUnique(args); - this.utils.postProcessForRead(result); + this.utils.postProcessForRead(result, args); return result; } diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index cedadb5cd..22239cbfc 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -666,6 +666,47 @@ export class PolicyUtil { return { result, error: undefined }; } + injectReadCheckSelect(args: any) { + const readFieldSelect = this.getReadFieldSelect(args.model); + if (!readFieldSelect) { + return; + } + this.doInjectReadCheckSelect(args, readFieldSelect); + } + + private doInjectReadCheckSelect(args: any, input: any) { + let target: any; + let isInclude = false; + + if (args.select) { + target = args.select; + isInclude = false; + } else if (args.include) { + target = args.include; + isInclude = true; + } else { + target = args.select = {}; + isInclude = false; + } + + if (!isInclude) { + // merge selects + for (const [k, v] of Object.entries(input.select)) { + if (v === true) { + if (!target[k]?.select) { + target[k].select = true; + } + } + } + } + + for (const [k, v] of Object.entries(input)) { + if (typeof v === 'object' && v?.select) { + this.doInjectReadCheckSelect(target[k], v); + } + } + } + //#endregion //#region Errors @@ -712,6 +753,14 @@ export class PolicyUtil { return guard.preValueSelect; } + getReadFieldSelect(model: string): object | undefined { + const guard = this.policy.guard[lowerCaseFirst(model)]; + if (!guard) { + throw this.unknownError(`unable to load policy guard for ${model}`); + } + return guard.readFieldSelect; + } + private hasFieldValidation(model: string): boolean { return this.policy.validation?.[lowerCaseFirst(model)]?.hasValidation === true; } @@ -732,7 +781,7 @@ export class PolicyUtil { /** * Post processing checks and clean-up for read model entities. */ - postProcessForRead(data: any) { + postProcessForRead(data: any, queryArgs: any) { if (data === null || data === undefined) { return; } diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index 72ea092d0..55953f813 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -40,6 +40,7 @@ export type PolicyDef = { create_input: InputCheckFunc; } & { preValueSelect?: object; + readFieldSelect?: object; } >; validation: Record; diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index bacdc688c..006d1a835 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -1,6 +1,11 @@ import { DataModel, + DataModelAttribute, + DataModelField, + DataModelFieldAttribute, Expression, + MemberAccessExpr, + Model, isBinaryExpr, isDataModel, isDataModelField, @@ -11,29 +16,27 @@ import { isReferenceExpr, isThisExpr, isUnaryExpr, - MemberAccessExpr, - Model, } from '@zenstackhq/language/ast'; import type { PolicyKind, PolicyOperationKind } from '@zenstackhq/runtime'; import { + ExpressionContext, + PluginError, + PluginOptions, + RUNTIME_PACKAGE, analyzePolicies, createProject, emitProject, - ExpressionContext, getDataModels, getLiteral, getPrismaClientImportSpec, hasAttribute, hasValidationAttributes, isForeignKeyField, - PluginError, - PluginOptions, - resolved, resolvePath, - RUNTIME_PACKAGE, + resolved, saveProject, } from '@zenstackhq/sdk'; -import { streamAllContents } from 'langium'; +import { findRootNode, streamAllContents } from 'langium'; import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; import { @@ -143,8 +146,10 @@ export default class PolicyGenerator { } } - private getPolicyExpressions(model: DataModel, kind: PolicyKind, operation: PolicyOperationKind) { - const attrs = model.attributes.filter((attr) => attr.decl.ref?.name === `@@${kind}`); + private getPolicyExpressions(target: DataModel | DataModelField, kind: PolicyKind, operation: PolicyOperationKind) { + const attributes = target.attributes as (DataModelAttribute | DataModelFieldAttribute)[]; + const attrName = isDataModel(target) ? `@@${kind}` : `@${kind}`; + const attrs = attributes.filter((attr) => attr.decl.ref?.name === attrName); const checkOperation = operation === 'postUpdate' ? 'update' : operation; @@ -248,7 +253,7 @@ export default class PolicyGenerator { result[kind] = guardFunc.getName()!; if (kind === 'postUpdate') { - const preValueSelect = this.generatePreValueSelect(model, allows, denies); + const preValueSelect = this.generatePreValueSelect(allows, denies); if (preValueSelect) { result['preValueSelect'] = preValueSelect; } @@ -259,10 +264,113 @@ export default class PolicyGenerator { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion result[kind + '_input'] = inputCheckFunc.getName()!; } + + const allFieldsAllows: Expression[] = []; + const allFieldsDenies: Expression[] = []; + + for (const field of model.fields) { + const allows = this.getPolicyExpressions(field, 'allow', 'read'); + const denies = this.getPolicyExpressions(field, 'deny', 'read'); + allFieldsAllows.push(...allows); + allFieldsDenies.push(...denies); + + if (denies.length === 0 && allows.length === 0) { + continue; + } + + const guardFunc = this.generateReadFieldGuardFunction(sourceFile, field, allows, denies); + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + result[`readFieldCheck$${field.name}`] = guardFunc.getName()!; + } + + const readFieldCheckSelect = this.generatePreValueSelect(allFieldsAllows, allFieldsDenies); + if (readFieldCheckSelect) { + result[`readFieldSelect`] = readFieldCheckSelect; + } } return result; } + private generateReadFieldGuardFunction( + sourceFile: SourceFile, + field: DataModelField, + allows: Expression[], + denies: Expression[] + ) { + const statements: (string | WriterFunction | StatementStructures)[] = []; + + // check if any allow or deny rule contains 'auth()' invocation + const hasAuthRef = [...denies, ...allows].some((rule) => + [...this.allNodes(rule)].some((child) => isAuthInvocation(child)) + ); + + if (hasAuthRef) { + const root = findRootNode(field) as Model; + const userModel = root.declarations.find( + (decl): decl is DataModel => isDataModel(decl) && decl.name === 'User' + ); + if (!userModel) { + throw new PluginError(name, 'User model not found'); + } + const userIdFields = getIdFields(userModel); + if (!userIdFields || userIdFields.length === 0) { + throw new PluginError(name, 'User model does not have an id field'); + } + + // normalize user to null to avoid accidentally use undefined in filter + statements.push( + `const user = hasAllFields(context.user, [${userIdFields + .map((f) => "'" + f.name + "'") + .join(', ')}]) ? context.user as any : null;` + ); + } + + statements.push((writer) => { + const transformer = new TypeScriptExpressionTransformer({ + context: ExpressionContext.AccessPolicy, + fieldReferenceContext: 'input', + }); + + let expr = + denies.length > 0 + ? '!(' + + denies + .map((deny) => { + return transformer.transform(deny); + }) + .join(' || ') + + ')' + : undefined; + + const allowStmt = allows + .map((allow) => { + return transformer.transform(allow); + }) + .join(' || '); + + expr = expr ? `${expr} && (${allowStmt})` : allowStmt; + writer.write('return ' + expr); + }); + + const func = sourceFile.addFunction({ + name: `${field.$container.name}$${field.name}_read`, + returnType: 'boolean', + parameters: [ + { + name: 'input', + type: 'any', + }, + { + name: 'context', + type: 'QueryContext', + }, + ], + statements, + }); + + return func; + } + private canCheckCreateBasedOnInput(model: DataModel, allows: Expression[], denies: Expression[]) { return [...allows, ...denies].every((rule) => { return [...this.allNodes(rule)].every((expr) => { @@ -301,7 +409,7 @@ export default class PolicyGenerator { // generates an object that can be used as the 'select' argument when fetching pre-update // entity value - private generatePreValueSelect(model: DataModel, allows: Expression[], denies: Expression[]): object { + private generatePreValueSelect(allows: Expression[], denies: Expression[]): object { // eslint-disable-next-line @typescript-eslint/no-explicit-any const result: any = {}; const addPath = (path: string[]) => { diff --git a/packages/schema/src/res/stdlib.zmodel b/packages/schema/src/res/stdlib.zmodel index d3a5574d1..a9a892873 100644 --- a/packages/schema/src/res/stdlib.zmodel +++ b/packages/schema/src/res/stdlib.zmodel @@ -351,11 +351,21 @@ attribute @@schema(_ name: String) @@@prisma */ attribute @@allow(_ operation: String, _ condition: Boolean) +/** + * Defines an access policy that allows a set of operations when the given condition is true. + */ +attribute @allow(_ operation: String, _ condition: Boolean) + /** * Defines an access policy that denies a set of operations when the given condition is true. */ attribute @@deny(_ operation: String, _ condition: Boolean) +/** + * Defines an access policy that denies a set of operations when the given condition is true. + */ +attribute @deny(_ operation: String, _ condition: Boolean) + /** * Indicates that the field is a password field and needs to be hashed before persistence. * diff --git a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts new file mode 100644 index 000000000..cfd5faa82 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts @@ -0,0 +1,45 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import path from 'path'; + +describe('With Policy: field-level policy', () => { + let origDir: string; + + beforeAll(async () => { + origDir = path.resolve('.'); + }); + + afterEach(() => { + process.chdir(origDir); + }); + + it('read', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + admin Boolean @default(false) + } + + model Model { + id Int @id @default(autoincrement()) + x Int + y Int @allow('read', x > 0 || auth().admin) + + @@allow('all', true) + } + ` + ); + + await prisma.model.create({ + data: { + id: 1, + x: 0, + y: 0, + }, + }); + + const db = withPolicy(); + const r = await db.model.findUnique({ where: { id: 1 } }); + expect(r.y).toBeUndefined(); + }); +}); From 894e82a8b27c1f56ee98edda7cbd58369be18400 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sat, 19 Aug 2023 09:46:25 +0800 Subject: [PATCH 2/6] WIP --- .../src/enhancements/policy/handler.ts | 41 +- .../src/enhancements/policy/policy-utils.ts | 195 ++++- packages/runtime/src/enhancements/types.ts | 7 +- packages/schema/src/language-server/utils.ts | 30 +- .../attribute-application-validator.ts | 285 ++++++++ .../validator/attribute-validator.ts | 7 +- .../validator/datamodel-validator.ts | 9 +- .../validator/enum-validator.ts | 11 +- .../validator/function-decl-validator.ts | 6 +- .../function-invocation-validator.ts | 3 +- .../src/language-server/validator/utils.ts | 132 ---- .../src/language-server/zmodel-linker.ts | 2 +- .../access-policy/expression-writer.ts | 2 +- .../access-policy/policy-guard-generator.ts | 80 ++- .../schema/src/plugins/access-policy/utils.ts | 10 - packages/schema/src/plugins/zod/generator.ts | 2 +- packages/schema/src/res/stdlib.zmodel | 2 +- packages/schema/src/utils/ast-utils.ts | 2 +- .../typescript-expression-transformer.ts | 4 +- .../validation/attribute-validation.test.ts | 54 ++ packages/sdk/src/constants.ts | 2 + packages/sdk/src/utils.ts | 20 +- pnpm-lock.yaml | 32 +- .../with-policy/field-level-policy.test.ts | 675 +++++++++++++++++- tests/integration/tsconfig.json | 3 +- 25 files changed, 1346 insertions(+), 270 deletions(-) create mode 100644 packages/schema/src/language-server/validator/attribute-application-validator.ts delete mode 100644 packages/schema/src/plugins/access-policy/utils.ts diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 235282932..78a5d1400 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -64,18 +64,19 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'where field is required in query argument'); } + const origArgs = args; args = this.utils.clone(args); if (!(await this.utils.injectForRead(this.prisma, this.model, args))) { return null; } - this.utils.injectReadCheckSelect(args); + this.utils.injectReadCheckSelect(this.model, args); if (this.shouldLogQuery) { this.logger.info(`[policy] \`findUnique\` ${this.model}:\n${formatObject(args)}`); } const result = await this.modelClient.findUnique(args); - this.utils.postProcessForRead(result, args); + this.utils.postProcessForRead(result, this.model, origArgs); return result; } @@ -87,58 +88,70 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'where field is required in query argument'); } + const origArgs = args; args = this.utils.clone(args); if (!(await this.utils.injectForRead(this.prisma, this.model, args))) { throw this.utils.notFound(this.model); } + this.utils.injectReadCheckSelect(this.model, args); + if (this.shouldLogQuery) { this.logger.info(`[policy] \`findUniqueOrThrow\` ${this.model}:\n${formatObject(args)}`); } const result = await this.modelClient.findUniqueOrThrow(args); - this.utils.postProcessForRead(result); + this.utils.postProcessForRead(result, this.model, origArgs); return result; } async findFirst(args: any) { + const origArgs = args; args = args ? this.utils.clone(args) : {}; if (!(await this.utils.injectForRead(this.prisma, this.model, args))) { return null; } + this.utils.injectReadCheckSelect(this.model, args); + if (this.shouldLogQuery) { this.logger.info(`[policy] \`findFirst\` ${this.model}:\n${formatObject(args)}`); } const result = await this.modelClient.findFirst(args); - this.utils.postProcessForRead(result); + this.utils.postProcessForRead(result, this.model, origArgs); return result; } async findFirstOrThrow(args: any) { + const origArgs = args; args = args ? this.utils.clone(args) : {}; if (!(await this.utils.injectForRead(this.prisma, this.model, args))) { throw this.utils.notFound(this.model); } + this.utils.injectReadCheckSelect(this.model, args); + if (this.shouldLogQuery) { this.logger.info(`[policy] \`findFirstOrThrow\` ${this.model}:\n${formatObject(args)}`); } const result = await this.modelClient.findFirstOrThrow(args); - this.utils.postProcessForRead(result); + this.utils.postProcessForRead(result, this.model, origArgs); return result; } async findMany(args: any) { + const origArgs = args; args = args ? this.utils.clone(args) : {}; if (!(await this.utils.injectForRead(this.prisma, this.model, args))) { return []; } + this.utils.injectReadCheckSelect(this.model, args); + if (this.shouldLogQuery) { this.logger.info(`[policy] \`findMany\` ${this.model}:\n${formatObject(args)}`); } const result = await this.modelClient.findMany(args); - this.utils.postProcessForRead(result); + this.utils.postProcessForRead(result, this.model, origArgs); return result; } @@ -257,7 +270,7 @@ export class PolicyProxyHandler implements Pr if (backLinkField?.isRelationOwner) { // the target side of relation owns the relation, // check if it's updatable - await this.utils.checkPolicyForUnique(model, args.where, 'update', db); + await this.utils.checkPolicyForUnique(model, args.where, 'update', db, args); } } @@ -302,7 +315,7 @@ export class PolicyProxyHandler implements Pr // the target side of relation owns the relation, // check if it's updatable - await this.utils.checkPolicyForUnique(model, args, 'update', db); + await this.utils.checkPolicyForUnique(model, args, 'update', db, args); } } }, @@ -599,7 +612,7 @@ export class PolicyProxyHandler implements Pr const backLinkField = this.utils.getModelField(model, context.field.backLink); if (backLinkField.isRelationOwner) { // update happens on the related model, require updatable - await this.utils.checkPolicyForUnique(model, args, 'update', db); + await this.utils.checkPolicyForUnique(model, args, 'update', db, args); // register post-update check await _registerPostUpdateCheck(model, args); @@ -640,7 +653,7 @@ export class PolicyProxyHandler implements Pr this.utils.tryReject(db, this.model, 'update'); // check pre-update guard - await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db); + await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args); // handles the case where id fields are updated const ids = this.utils.clone(existing); @@ -723,7 +736,7 @@ export class PolicyProxyHandler implements Pr // update case // check pre-update guard - await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db); + await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args); // register post-update check await _registerPostUpdateCheck(model, uniqueFilter); @@ -791,7 +804,7 @@ export class PolicyProxyHandler implements Pr await this.utils.checkExistence(db, model, uniqueFilter, true); // check delete guard - await this.utils.checkPolicyForUnique(model, uniqueFilter, 'delete', db); + await this.utils.checkPolicyForUnique(model, uniqueFilter, 'delete', db, args); }, deleteMany: async (model, args, context) => { @@ -944,7 +957,7 @@ export class PolicyProxyHandler implements Pr await this.utils.checkExistence(tx, this.model, args.where, true); // inject delete guard - await this.utils.checkPolicyForUnique(this.model, args.where, 'delete', tx); + await this.utils.checkPolicyForUnique(this.model, args.where, 'delete', tx, args); // proceed with the deletion if (this.shouldLogQuery) { @@ -1039,7 +1052,7 @@ export class PolicyProxyHandler implements Pr private async runPostWriteChecks(postWriteChecks: PostWriteCheckRecord[], db: Record) { await Promise.all( postWriteChecks.map(async ({ model, operation, uniqueFilter, preValue }) => - this.utils.checkPolicyForUnique(model, uniqueFilter, operation, db, preValue) + this.utils.checkPolicyForUnique(model, uniqueFilter, operation, db, undefined, preValue) ) ); } diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 93cf5d832..bff46ae13 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -9,7 +9,7 @@ import { AuthUser, DbClientContract, DbOperations, FieldInfo, PolicyOperationKin import { getVersion } from '../../version'; import { getFields, resolveField } from '../model-meta'; import { NestedWriteVisitorContext } from '../nested-write-vistor'; -import type { InputCheckFunc, ModelMeta, PolicyDef, PolicyFunc, ZodSchemas } from '../types'; +import type { InputCheckFunc, ModelMeta, PolicyDef, PolicyFunc, ReadFieldCheckFunc, ZodSchemas } from '../types'; import { enumerate, formatObject, @@ -45,14 +45,14 @@ export class PolicyUtil { /** * Creates a conjunction of a list of query conditions. */ - and(...conditions: (boolean | object)[]): object { + and(...conditions: (boolean | object | undefined)[]): object { return this.reduce({ AND: conditions }); } /** * Creates a disjunction of a list of query conditions. */ - or(...conditions: (boolean | object)[]): object { + or(...conditions: (boolean | object | undefined)[]): object { return this.reduce({ OR: conditions }); } @@ -173,7 +173,7 @@ export class PolicyUtil { throw this.unknownError(`unable to load policy guard for ${model}`); } - const provider: PolicyFunc | boolean | undefined = guard[operation]; + const provider = guard[operation]; if (typeof provider === 'boolean') { return this.reduce(provider); } @@ -185,6 +185,24 @@ export class PolicyUtil { return this.reduce(r); } + getFieldUpdateAuthGuard(db: Record, model: string, field: string): object { + const guard = this.policy.guard[lowerCaseFirst(model)]; + if (!guard) { + throw this.unknownError(`unable to load policy guard for ${model}`); + } + + const provider = guard[`updateFieldGuard$${field}`]; + if (typeof provider === 'boolean') { + return this.reduce(provider); + } + + if (!provider) { + return this.makeTrue(); + } + const r = provider({ user: this.user }, db); + return this.reduce(r); + } + /** * Checks if the given model has a policy guard for the given operation. */ @@ -225,12 +243,23 @@ export class PolicyUtil { * Injects model auth guard as where clause. */ async injectAuthGuard(db: Record, args: any, model: string, operation: PolicyOperationKind) { - const guard = this.getAuthGuard(db, model, operation); + let guard = this.getAuthGuard(db, model, operation); if (this.isFalse(guard)) { args.where = this.makeFalse(); return false; } + if (operation === 'update' && args) { + // merge field-level policy guards + const fieldUpdateGuard = this.getFieldUpdateGuards(db, model, args); + if (fieldUpdateGuard.rejectedByField) { + args.where = this.makeFalse(); + return false; + } else if (fieldUpdateGuard.guard) { + guard = this.and(guard, fieldUpdateGuard.guard); + } + } + if (args.where) { // inject into relation fields: // to-many: some/none/every @@ -545,13 +574,30 @@ export class PolicyUtil { uniqueFilter: any, operation: PolicyOperationKind, db: Record, + args: any, preValue?: any ) { - const guard = this.getAuthGuard(db, model, operation, preValue); + let guard = this.getAuthGuard(db, model, operation, preValue); if (this.isFalse(guard)) { throw this.deniedByPolicy(model, operation, `entity ${formatObject(uniqueFilter)} failed policy check`); } + if (operation === 'update' && args) { + // merge field-level policy guards + const fieldUpdateGuard = this.getFieldUpdateGuards(db, model, args); + if (fieldUpdateGuard.rejectedByField) { + throw this.deniedByPolicy( + model, + 'update', + `entity ${formatObject(uniqueFilter)} failed update policy check for field "${ + fieldUpdateGuard.rejectedByField + }"` + ); + } else if (fieldUpdateGuard.guard) { + guard = this.and(guard, fieldUpdateGuard.guard); + } + } + // Zod schema is to be checked for "create" and "postUpdate" const schema = ['create', 'postUpdate'].includes(operation) ? this.getZodSchema(model) : undefined; @@ -600,6 +646,21 @@ export class PolicyUtil { } } + private getFieldUpdateGuards(db: Record, model: string, args: any) { + let allFieldGuards; + for (const [k, v] of Object.entries(args.data ?? args)) { + if (typeof v === 'undefined') { + continue; + } + const fieldGuard = this.getFieldUpdateAuthGuard(db, model, k); + if (this.isFalse(fieldGuard)) { + return { guard: allFieldGuards, rejectedByField: k }; + } + allFieldGuards = this.and(allFieldGuards, fieldGuard); + } + return { guard: allFieldGuards, rejectedByField: undefined }; + } + /** * Tries rejecting a request based on static "false" policy. */ @@ -648,6 +709,7 @@ export class PolicyUtil { uniqueFilter = this.clone(uniqueFilter); this.flattenGeneratedUniqueField(model, uniqueFilter); const readArgs = { select: selectInclude.select, include: selectInclude.include, where: uniqueFilter }; + const error = this.deniedByPolicy( model, operation, @@ -660,6 +722,9 @@ export class PolicyUtil { return { error, result: undefined }; } + // inject select needed for field-level read checks + this.injectReadCheckSelect(model, readArgs); + if (this.shouldLogQuery) { this.logger.info(`[policy] checking read-back, \`findFirst\` ${model}:\n${formatObject(readArgs)}`); } @@ -668,19 +733,23 @@ export class PolicyUtil { return { error, result: undefined }; } - this.postProcessForRead(result); + this.postProcessForRead(result, model, selectInclude); return { result, error: undefined }; } - injectReadCheckSelect(args: any) { - const readFieldSelect = this.getReadFieldSelect(args.model); + injectReadCheckSelect(model: string, args: any) { + const readFieldSelect = this.getReadFieldSelect(model); if (!readFieldSelect) { return; } - this.doInjectReadCheckSelect(args, readFieldSelect); + this.doInjectReadCheckSelect(model, args, { select: readFieldSelect }); } - private doInjectReadCheckSelect(args: any, input: any) { + private doInjectReadCheckSelect(model: string, args: any, input: any) { + if (!input.select) { + return; + } + let target: any; let isInclude = false; @@ -691,7 +760,7 @@ export class PolicyUtil { target = args.include; isInclude = true; } else { - target = args.select = {}; + target = args.select = this.makeAllScalarFieldSelect(model); isInclude = false; } @@ -699,20 +768,44 @@ export class PolicyUtil { // merge selects for (const [k, v] of Object.entries(input.select)) { if (v === true) { - if (!target[k]?.select) { - target[k].select = true; + if (!target[k]) { + target[k] = true; } } } } - for (const [k, v] of Object.entries(input)) { + for (const [k, v] of Object.entries(input.select)) { if (typeof v === 'object' && v?.select) { - this.doInjectReadCheckSelect(target[k], v); + const field = resolveField(this.modelMeta, model, k); + if (field && field.isDataModel) { + // recurse into relation + if (isInclude && target[k] === true) { + // select all fields for the relation + target[k] = { select: this.makeAllScalarFieldSelect(field.type) }; + } else if (!target[k]) { + // ensure an empty select clause + target[k] = { select: {} }; + } + this.doInjectReadCheckSelect(field.type, target[k], v); + } } } } + private makeAllScalarFieldSelect(model: string): any { + const fields = this.modelMeta.fields[lowerCaseFirst(model)]; + const result: any = {}; + if (fields) { + Object.entries(fields).forEach(([k, v]) => { + if (!v.isDataModel) { + result[k] = true; + } + }); + } + return result; + } + //#endregion //#region Errors @@ -767,6 +860,19 @@ export class PolicyUtil { return guard.readFieldSelect; } + checkReadField(model: string, field: string, entity: any) { + const guard = this.policy.guard[lowerCaseFirst(model)]; + if (!guard) { + throw this.unknownError(`unable to load policy guard for ${model}`); + } + const func = guard[`readFieldCheck$${field}`] as ReadFieldCheckFunc | undefined; + if (!func) { + return true; + } else { + return func(entity, { user: this.user }); + } + } + private hasFieldValidation(model: string): boolean { return this.policy.validation?.[lowerCaseFirst(model)]?.hasValidation === true; } @@ -787,7 +893,12 @@ export class PolicyUtil { /** * Post processing checks and clean-up for read model entities. */ - postProcessForRead(data: any, queryArgs: any) { + postProcessForRead(data: any, model: string, queryArgs: any) { + const origData = this.clone(data); + this.doPostProcessForRead(data, model, origData, queryArgs); + } + + private doPostProcessForRead(data: any, model: string, fullData: any, queryArgs: any, path = '') { if (data === null || data === undefined) { return; } @@ -804,11 +915,55 @@ export class PolicyUtil { } } - for (const fieldData of Object.values(entityData)) { - if (typeof fieldData !== 'object' || !fieldData) { + for (const [field, fieldData] of Object.entries(entityData)) { + if (fieldData === undefined) { continue; } - this.postProcessForRead(fieldData); + + const fieldInfo = resolveField(this.modelMeta, model, field); + if (!fieldInfo) { + // could be _count, etc. + continue; + } + + if (!fieldInfo.isDataModel) { + // scalar field, delete unselected ones + const select = queryArgs?.select; + if (select && typeof select === 'object' && select[field] !== true) { + // there's a select clause but this field is not included + delete entityData[field]; + continue; + } + } else { + // relation field, delete if not included + const include = queryArgs?.include; + const select = queryArgs?.select; + if (!include?.[field] && !select?.[field]) { + // relation field not included or selected + delete entityData[field]; + continue; + } + } + + // delete unreadable fields + if (!this.checkReadField(model, field, fullData)) { + if (this.shouldLogQuery) { + this.logger.info(`[policy] dropping unreadable field ${path ? path + '.' : ''}${field}`); + } + delete entityData[field]; + continue; + } + + if (fieldInfo.isDataModel) { + const nextArgs = (queryArgs?.select ?? queryArgs?.include)?.[field]; + this.doPostProcessForRead( + fieldData, + fieldInfo.type, + fullData[field], + nextArgs, + path ? path + '.' + field : field + ); + } } } } diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index 55953f813..f227eb5ee 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -26,6 +26,11 @@ export type PolicyFunc = (context: QueryContext, db: Record boolean; +/** + * Function for getting policy guard with a given context + */ +export type ReadFieldCheckFunc = (input: any, context: QueryContext) => boolean; + /** * Policy definition */ @@ -41,7 +46,7 @@ export type PolicyDef = { } & { preValueSelect?: object; readFieldSelect?: object; - } + } & Record >; validation: Record; }; diff --git a/packages/schema/src/language-server/utils.ts b/packages/schema/src/language-server/utils.ts index 3a26d112b..c57836a91 100644 --- a/packages/schema/src/language-server/utils.ts +++ b/packages/schema/src/language-server/utils.ts @@ -1,33 +1,5 @@ -import { - DataModel, - DataModelField, - isArrayExpr, - isModel, - isReferenceExpr, - Model, - ReferenceExpr, -} from '@zenstackhq/language/ast'; +import { DataModel, DataModelField, isArrayExpr, isReferenceExpr, ReferenceExpr } from '@zenstackhq/language/ast'; import { resolved } from '@zenstackhq/sdk'; -import { AstNode } from 'langium'; -import { STD_LIB_MODULE_NAME } from './constants'; - -/** - * Gets the toplevel Model containing the given node. - */ -export function getContainingModel(node: AstNode | undefined): Model | null { - if (!node) { - return null; - } - return isModel(node) ? node : getContainingModel(node.$container); -} - -/** - * Returns if the given node is declared in stdlib. - */ -export function isFromStdlib(node: AstNode) { - const model = getContainingModel(node); - return !!model && !!model.$document && model.$document.uri.path.endsWith(STD_LIB_MODULE_NAME); -} /** * Gets lists of unique fields declared at the data model level diff --git a/packages/schema/src/language-server/validator/attribute-application-validator.ts b/packages/schema/src/language-server/validator/attribute-application-validator.ts new file mode 100644 index 000000000..76024c236 --- /dev/null +++ b/packages/schema/src/language-server/validator/attribute-application-validator.ts @@ -0,0 +1,285 @@ +import { + ArrayExpr, + Attribute, + AttributeArg, + AttributeParam, + DataModelAttribute, + DataModelField, + DataModelFieldAttribute, + InternalAttribute, + ReferenceExpr, + isArrayExpr, + isAttribute, + isDataModel, + isDataModelField, + isEnum, + isReferenceExpr, +} from '@zenstackhq/language/ast'; +import { ValidationAcceptor, streamAllContents } from 'langium'; +import { AstValidator } from '../types'; +import pluralize from 'pluralize'; +import { isFutureExpr, resolved } from '@zenstackhq/sdk'; +import { getStringLiteral, mapBuiltinTypeToExpressionType, typeAssignable } from './utils'; + +// a registry of function handlers marked with @func +const attributeCheckers = new Map(); + +// function handler decorator +function check(name: string) { + return function (_target: unknown, _propertyKey: string, descriptor: PropertyDescriptor) { + if (!attributeCheckers.get(name)) { + attributeCheckers.set(name, descriptor); + } + return descriptor; + }; +} + +type AttributeApplication = DataModelAttribute | DataModelFieldAttribute | InternalAttribute; + +/** + * Validates function declarations. + */ +export default class AttributeApplicationValidator implements AstValidator { + validate(attr: AttributeApplication, accept: ValidationAcceptor) { + const decl = attr.decl.ref; + if (!decl) { + return; + } + + const targetDecl = attr.$container; + if (decl.name === '@@@targetField' && !isAttribute(targetDecl)) { + accept('error', `attribute "${decl.name}" can only be used on attribute declarations`, { node: attr }); + return; + } + + if (isDataModelField(targetDecl) && !isValidAttributeTarget(decl, targetDecl)) { + accept('error', `attribute "${decl.name}" cannot be used on this type of field`, { node: attr }); + } + + const filledParams = new Set(); + + for (const arg of attr.args) { + let paramDecl: AttributeParam | undefined; + if (!arg.name) { + paramDecl = decl.params.find((p) => p.default && !filledParams.has(p)); + if (!paramDecl) { + accept('error', `Unexpected unnamed argument`, { + node: arg, + }); + return; + } + } else { + paramDecl = decl.params.find((p) => p.name === arg.name); + if (!paramDecl) { + accept('error', `Attribute "${decl.name}" doesn't have a parameter named "${arg.name}"`, { + node: arg, + }); + return; + } + } + + if (!assignableToAttributeParam(arg, paramDecl, attr)) { + accept('error', `Value is not assignable to parameter`, { + node: arg, + }); + return; + } + + if (filledParams.has(paramDecl)) { + accept('error', `Parameter "${paramDecl.name}" is already provided`, { node: arg }); + return; + } + filledParams.add(paramDecl); + arg.$resolvedParam = paramDecl; + } + + const missingParams = decl.params.filter((p) => !p.type.optional && !filledParams.has(p)); + if (missingParams.length > 0) { + accept( + 'error', + `Required ${pluralize('parameter', missingParams.length)} not provided: ${missingParams + .map((p) => p.name) + .join(', ')}`, + { node: attr } + ); + return; + } + + // run checkers for specific attributes + const checker = attributeCheckers.get(decl.name); + if (checker) { + checker.value.call(this, attr, accept); + } + } + + @check('@@allow') + @check('@@deny') + private _checkModelLevelPolicy(attr: AttributeApplication, accept: ValidationAcceptor) { + const kind = getStringLiteral(attr.args[0].value); + if (!kind) { + accept('error', `expects a string literal`, { node: attr.args[0] }); + return; + } + this.validatePolicyKinds(kind, ['create', 'read', 'update', 'delete', 'all'], attr, accept); + } + + @check('@allow') + @check('@deny') + private _checkFieldLevelPolicy(attr: AttributeApplication, accept: ValidationAcceptor) { + const kind = getStringLiteral(attr.args[0].value); + if (!kind) { + accept('error', `expects a string literal`, { node: attr.args[0] }); + return; + } + this.validatePolicyKinds(kind, ['read', 'update', 'all'], attr, accept); + + const expr = attr.args[1].value; + if ([expr, ...streamAllContents(expr)].some((node) => isFutureExpr(node))) { + accept('error', `"future()" is not allowed in field-level policy rules`, { node: expr }); + } + } + + private validatePolicyKinds( + kind: string, + candidates: string[], + attr: AttributeApplication, + accept: ValidationAcceptor + ) { + const items = kind.split(',').map((x) => x.trim()); + items.forEach((item) => { + if (!candidates.includes(item)) { + accept( + 'error', + `Invalid policy rule kind: "${item}", allowed: ${candidates.map((c) => '"' + c + '"').join(', ')}`, + { node: attr } + ); + } + }); + } +} + +function assignableToAttributeParam(arg: AttributeArg, param: AttributeParam, attr: AttributeApplication): boolean { + const argResolvedType = arg.$resolvedType; + if (!argResolvedType) { + return false; + } + + let dstType = param.type.type; + let dstIsArray = param.type.array; + const dstRef = param.type.reference; + + if (dstType === 'Any' && !dstIsArray) { + return true; + } + + // destination is field reference or transitive field reference, check if + // argument is reference or array or reference + if (dstType === 'FieldReference' || dstType === 'TransitiveFieldReference') { + if (dstIsArray) { + return ( + isArrayExpr(arg.value) && + !arg.value.items.find((item) => !isReferenceExpr(item) || !isDataModelField(item.target.ref)) + ); + } else { + return isReferenceExpr(arg.value) && isDataModelField(arg.value.target.ref); + } + } + + if (isEnum(argResolvedType.decl)) { + // enum type + + let attrArgDeclType = dstRef?.ref; + if (dstType === 'ContextType' && isDataModelField(attr.$container) && attr.$container?.type?.reference) { + // attribute parameter type is ContextType, need to infer type from + // the attribute's container + attrArgDeclType = resolved(attr.$container.type.reference); + dstIsArray = attr.$container.type.array; + } + return attrArgDeclType === argResolvedType.decl && dstIsArray === argResolvedType.array; + } else if (dstType) { + // scalar type + + if (typeof argResolvedType?.decl !== 'string') { + // destination type is not a reference, so argument type must be a plain expression + return false; + } + + if (dstType === 'ContextType') { + // attribute parameter type is ContextType, need to infer type from + // the attribute's container + if (isDataModelField(attr.$container)) { + if (!attr.$container?.type?.type) { + return false; + } + dstType = mapBuiltinTypeToExpressionType(attr.$container.type.type); + dstIsArray = attr.$container.type.array; + } else { + dstType = 'Any'; + } + } + + return typeAssignable(dstType, argResolvedType.decl, arg.value) && dstIsArray === argResolvedType.array; + } else { + // reference type + return (dstRef?.ref === argResolvedType.decl || dstType === 'Any') && dstIsArray === argResolvedType.array; + } +} + +function isValidAttributeTarget(attrDecl: Attribute, targetDecl: DataModelField) { + const targetField = attrDecl.attributes.find((attr) => attr.decl.ref?.name === '@@@targetField'); + if (!targetField) { + // no field type constraint + return true; + } + + const fieldTypes = (targetField.args[0].value as ArrayExpr).items.map( + (item) => (item as ReferenceExpr).target.ref?.name + ); + + let allowed = false; + for (const allowedType of fieldTypes) { + switch (allowedType) { + case 'StringField': + allowed = allowed || targetDecl.type.type === 'String'; + break; + case 'IntField': + allowed = allowed || targetDecl.type.type === 'Int'; + break; + case 'BigIntField': + allowed = allowed || targetDecl.type.type === 'BigInt'; + break; + case 'FloatField': + allowed = allowed || targetDecl.type.type === 'Float'; + break; + case 'DecimalField': + allowed = allowed || targetDecl.type.type === 'Decimal'; + break; + case 'BooleanField': + allowed = allowed || targetDecl.type.type === 'Boolean'; + break; + case 'DateTimeField': + allowed = allowed || targetDecl.type.type === 'DateTime'; + break; + case 'JsonField': + allowed = allowed || targetDecl.type.type === 'Json'; + break; + case 'BytesField': + allowed = allowed || targetDecl.type.type === 'Bytes'; + break; + case 'ModelField': + allowed = allowed || isDataModel(targetDecl.type.reference?.ref); + break; + default: + break; + } + if (allowed) { + break; + } + } + + return allowed; +} + +export function validateAttributeApplication(attr: AttributeApplication, accept: ValidationAcceptor) { + new AttributeApplicationValidator().validate(attr, accept); +} diff --git a/packages/schema/src/language-server/validator/attribute-validator.ts b/packages/schema/src/language-server/validator/attribute-validator.ts index b13791207..1bf961159 100644 --- a/packages/schema/src/language-server/validator/attribute-validator.ts +++ b/packages/schema/src/language-server/validator/attribute-validator.ts @@ -1,11 +1,14 @@ import { Attribute } from '@zenstackhq/language/ast'; -import { AstValidator } from '../types'; import { ValidationAcceptor } from 'langium'; +import { AstValidator } from '../types'; +import { validateAttributeApplication } from './attribute-application-validator'; /** * Validates attribute declarations. */ export default class AttributeValidator implements AstValidator { // eslint-disable-next-line @typescript-eslint/no-unused-vars, @typescript-eslint/no-empty-function - validate(attr: Attribute, accept: ValidationAcceptor): void {} + validate(attr: Attribute, accept: ValidationAcceptor): void { + attr.attributes.forEach((attr) => validateAttributeApplication(attr, accept)); + } } diff --git a/packages/schema/src/language-server/validator/datamodel-validator.ts b/packages/schema/src/language-server/validator/datamodel-validator.ts index a4329e44d..a26536af4 100644 --- a/packages/schema/src/language-server/validator/datamodel-validator.ts +++ b/packages/schema/src/language-server/validator/datamodel-validator.ts @@ -6,12 +6,13 @@ import { isLiteralExpr, ReferenceExpr, } from '@zenstackhq/language/ast'; -import { analyzePolicies, getModelIdFields, getModelUniqueFields, getLiteral } from '@zenstackhq/sdk'; +import { analyzePolicies, getLiteral, getModelIdFields, getModelUniqueFields } from '@zenstackhq/sdk'; import { AstNode, DiagnosticInfo, getDocument, ValidationAcceptor } from 'langium'; import { IssueCodes, SCALAR_TYPES } from '../constants'; import { AstValidator } from '../types'; import { getUniqueFields } from '../utils'; -import { validateAttributeApplication, validateDuplicatedDeclarations } from './utils'; +import { validateAttributeApplication } from './attribute-application-validator'; +import { validateDuplicatedDeclarations } from './utils'; /** * Validates data model declarations. @@ -94,9 +95,7 @@ export default class DataModelValidator implements AstValidator { } private validateAttributes(dm: DataModel, accept: ValidationAcceptor) { - dm.attributes.forEach((attr) => { - validateAttributeApplication(attr, accept); - }); + dm.attributes.forEach((attr) => validateAttributeApplication(attr, accept)); } private parseRelation(field: DataModelField, accept?: ValidationAcceptor) { diff --git a/packages/schema/src/language-server/validator/enum-validator.ts b/packages/schema/src/language-server/validator/enum-validator.ts index 4453b2c12..4223d8a2b 100644 --- a/packages/schema/src/language-server/validator/enum-validator.ts +++ b/packages/schema/src/language-server/validator/enum-validator.ts @@ -1,7 +1,8 @@ import { Enum, EnumField } from '@zenstackhq/language/ast'; import { ValidationAcceptor } from 'langium'; import { AstValidator } from '../types'; -import { validateAttributeApplication, validateDuplicatedDeclarations } from './utils'; +import { validateAttributeApplication } from './attribute-application-validator'; +import { validateDuplicatedDeclarations } from './utils'; /** * Validates enum declarations. @@ -17,14 +18,10 @@ export default class EnumValidator implements AstValidator { } private validateAttributes(_enum: Enum, accept: ValidationAcceptor) { - _enum.attributes.forEach((attr) => { - validateAttributeApplication(attr, accept); - }); + _enum.attributes.forEach((attr) => validateAttributeApplication(attr, accept)); } private validateField(field: EnumField, accept: ValidationAcceptor) { - field.attributes.forEach((attr) => { - validateAttributeApplication(attr, accept); - }); + field.attributes.forEach((attr) => validateAttributeApplication(attr, accept)); } } diff --git a/packages/schema/src/language-server/validator/function-decl-validator.ts b/packages/schema/src/language-server/validator/function-decl-validator.ts index a9438a9a0..9ef56c468 100644 --- a/packages/schema/src/language-server/validator/function-decl-validator.ts +++ b/packages/schema/src/language-server/validator/function-decl-validator.ts @@ -1,15 +1,13 @@ import { FunctionDecl } from '@zenstackhq/language/ast'; import { ValidationAcceptor } from 'langium'; import { AstValidator } from '../types'; -import { validateAttributeApplication } from './utils'; +import { validateAttributeApplication } from './attribute-application-validator'; /** * Validates function declarations. */ export default class FunctionDeclValidator implements AstValidator { validate(funcDecl: FunctionDecl, accept: ValidationAcceptor) { - funcDecl.attributes.forEach((attr) => { - validateAttributeApplication(attr, accept); - }); + funcDecl.attributes.forEach((attr) => validateAttributeApplication(attr, accept)); } } diff --git a/packages/schema/src/language-server/validator/function-invocation-validator.ts b/packages/schema/src/language-server/validator/function-invocation-validator.ts index b6c321fec..3ccf5a260 100644 --- a/packages/schema/src/language-server/validator/function-invocation-validator.ts +++ b/packages/schema/src/language-server/validator/function-invocation-validator.ts @@ -14,10 +14,9 @@ import { import { AstNode, ValidationAcceptor } from 'langium'; import { getDataModelFieldReference } from '../../utils/ast-utils'; import { AstValidator } from '../types'; -import { isFromStdlib } from '../utils'; import { typeAssignable } from './utils'; import { match, P } from 'ts-pattern'; -import { ExpressionContext, getFunctionExpressionContext, isEnumFieldReference } from '@zenstackhq/sdk'; +import { ExpressionContext, getFunctionExpressionContext, isEnumFieldReference, isFromStdlib } from '@zenstackhq/sdk'; /** * InvocationExpr validation diff --git a/packages/schema/src/language-server/validator/utils.ts b/packages/schema/src/language-server/validator/utils.ts index fa8de9664..9bcb42110 100644 --- a/packages/schema/src/language-server/validator/utils.ts +++ b/packages/schema/src/language-server/validator/utils.ts @@ -1,6 +1,4 @@ import { - ArrayExpr, - Attribute, AttributeArg, AttributeParam, BuiltinType, @@ -11,17 +9,13 @@ import { ExpressionType, InternalAttribute, isArrayExpr, - isAttribute, - isDataModel, isDataModelField, isEnum, isLiteralExpr, isReferenceExpr, - ReferenceExpr, } from '@zenstackhq/language/ast'; import { resolved } from '@zenstackhq/sdk'; import { AstNode, ValidationAcceptor } from 'langium'; -import pluralize from 'pluralize'; /** * Checks if the given declarations have duplicated names @@ -191,129 +185,3 @@ export function assignableToAttributeParam( return (dstRef?.ref === argResolvedType.decl || dstType === 'Any') && dstIsArray === argResolvedType.array; } } - -export function validateAttributeApplication( - attr: DataModelAttribute | DataModelFieldAttribute | InternalAttribute, - accept: ValidationAcceptor -) { - const decl = attr.decl.ref; - if (!decl) { - return; - } - - const targetDecl = attr.$container; - if (decl.name === '@@@targetField' && !isAttribute(targetDecl)) { - accept('error', `attribute "${decl.name}" can only be used on attribute declarations`, { node: attr }); - return; - } - - if (isDataModelField(targetDecl) && !isValidAttributeTarget(decl, targetDecl)) { - accept('error', `attribute "${decl.name}" cannot be used on this type of field`, { node: attr }); - } - - const filledParams = new Set(); - - for (const arg of attr.args) { - let paramDecl: AttributeParam | undefined; - if (!arg.name) { - paramDecl = decl.params.find((p) => p.default && !filledParams.has(p)); - if (!paramDecl) { - accept('error', `Unexpected unnamed argument`, { - node: arg, - }); - return false; - } - } else { - paramDecl = decl.params.find((p) => p.name === arg.name); - if (!paramDecl) { - accept('error', `Attribute "${decl.name}" doesn't have a parameter named "${arg.name}"`, { - node: arg, - }); - return false; - } - } - - if (!assignableToAttributeParam(arg, paramDecl, attr)) { - accept('error', `Value is not assignable to parameter`, { - node: arg, - }); - return false; - } - - if (filledParams.has(paramDecl)) { - accept('error', `Parameter "${paramDecl.name}" is already provided`, { node: arg }); - return false; - } - filledParams.add(paramDecl); - arg.$resolvedParam = paramDecl; - } - - const missingParams = decl.params.filter((p) => !p.type.optional && !filledParams.has(p)); - if (missingParams.length > 0) { - accept( - 'error', - `Required ${pluralize('parameter', missingParams.length)} not provided: ${missingParams - .map((p) => p.name) - .join(', ')}`, - { node: attr } - ); - return false; - } - - return true; -} - -function isValidAttributeTarget(attrDecl: Attribute, targetDecl: DataModelField) { - const targetField = attrDecl.attributes.find((attr) => attr.decl.ref?.name === '@@@targetField'); - if (!targetField) { - // no field type constraint - return true; - } - - const fieldTypes = (targetField.args[0].value as ArrayExpr).items.map( - (item) => (item as ReferenceExpr).target.ref?.name - ); - - let allowed = false; - for (const allowedType of fieldTypes) { - switch (allowedType) { - case 'StringField': - allowed = allowed || targetDecl.type.type === 'String'; - break; - case 'IntField': - allowed = allowed || targetDecl.type.type === 'Int'; - break; - case 'BigIntField': - allowed = allowed || targetDecl.type.type === 'BigInt'; - break; - case 'FloatField': - allowed = allowed || targetDecl.type.type === 'Float'; - break; - case 'DecimalField': - allowed = allowed || targetDecl.type.type === 'Decimal'; - break; - case 'BooleanField': - allowed = allowed || targetDecl.type.type === 'Boolean'; - break; - case 'DateTimeField': - allowed = allowed || targetDecl.type.type === 'DateTime'; - break; - case 'JsonField': - allowed = allowed || targetDecl.type.type === 'Json'; - break; - case 'BytesField': - allowed = allowed || targetDecl.type.type === 'Bytes'; - break; - case 'ModelField': - allowed = allowed || isDataModel(targetDecl.type.reference?.ref); - break; - default: - break; - } - if (allowed) { - break; - } - } - - return allowed; -} diff --git a/packages/schema/src/language-server/zmodel-linker.ts b/packages/schema/src/language-server/zmodel-linker.ts index 4a8557bdd..7b9d42956 100644 --- a/packages/schema/src/language-server/zmodel-linker.ts +++ b/packages/schema/src/language-server/zmodel-linker.ts @@ -45,8 +45,8 @@ import { } from 'langium'; import { CancellationToken } from 'vscode-jsonrpc'; import { getAllDeclarationsFromImports } from '../utils/ast-utils'; -import { getContainingModel, isFromStdlib } from './utils'; import { mapBuiltinTypeToExpressionType } from './validator/utils'; +import { getContainingModel, isFromStdlib } from '@zenstackhq/sdk'; interface DefaultReference extends Reference { _ref?: AstNode | LinkingError; diff --git a/packages/schema/src/plugins/access-policy/expression-writer.ts b/packages/schema/src/plugins/access-policy/expression-writer.ts index d27b9ee55..24c9d6b6d 100644 --- a/packages/schema/src/plugins/access-policy/expression-writer.ts +++ b/packages/schema/src/plugins/access-policy/expression-writer.ts @@ -18,6 +18,7 @@ import { getFunctionExpressionContext, getLiteral, isDataModelFieldReference, + isFutureExpr, PluginError, } from '@zenstackhq/sdk'; import { lowerCaseFirst } from 'lower-case-first'; @@ -28,7 +29,6 @@ import { TypeScriptExpressionTransformer, TypeScriptExpressionTransformerError, } from '../../utils/typescript-expression-transformer'; -import { isFutureExpr } from './utils'; type ComparisonOperator = '==' | '!=' | '>' | '>=' | '<' | '<='; diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index 006d1a835..554e0b759 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -32,6 +32,8 @@ import { hasAttribute, hasValidationAttributes, isForeignKeyField, + isFromStdlib, + isFutureExpr, resolvePath, resolved, saveProject, @@ -47,7 +49,6 @@ import { WriterFunction, } from 'ts-morph'; import { name } from '.'; -import { isFromStdlib } from '../../language-server/utils'; import { getIdFields, isAuthInvocation } from '../../utils/ast-utils'; import { TypeScriptExpressionTransformer, @@ -55,7 +56,6 @@ import { } from '../../utils/typescript-expression-transformer'; import { ALL_OPERATION_KINDS, getDefaultOutputFolder } from '../plugin-utils'; import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; -import { isFutureExpr } from './utils'; /** * Generates source file that contains Prisma query guard objects used for injecting database queries @@ -264,31 +264,44 @@ export default class PolicyGenerator { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion result[kind + '_input'] = inputCheckFunc.getName()!; } + } - const allFieldsAllows: Expression[] = []; - const allFieldsDenies: Expression[] = []; + // generate field read checkers + this.generateReadFieldsGuards(model, sourceFile, result); - for (const field of model.fields) { - const allows = this.getPolicyExpressions(field, 'allow', 'read'); - const denies = this.getPolicyExpressions(field, 'deny', 'read'); - allFieldsAllows.push(...allows); - allFieldsDenies.push(...denies); + // generate field update guards + this.generateUpdateFieldsGuards(model, sourceFile, result); - if (denies.length === 0 && allows.length === 0) { - continue; - } + return result; + } - const guardFunc = this.generateReadFieldGuardFunction(sourceFile, field, allows, denies); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - result[`readFieldCheck$${field.name}`] = guardFunc.getName()!; - } + private generateReadFieldsGuards( + model: DataModel, + sourceFile: SourceFile, + result: Record + ) { + const allFieldsAllows: Expression[] = []; + const allFieldsDenies: Expression[] = []; - const readFieldCheckSelect = this.generatePreValueSelect(allFieldsAllows, allFieldsDenies); - if (readFieldCheckSelect) { - result[`readFieldSelect`] = readFieldCheckSelect; + for (const field of model.fields) { + const allows = this.getPolicyExpressions(field, 'allow', 'read'); + const denies = this.getPolicyExpressions(field, 'deny', 'read'); + allFieldsAllows.push(...allows); + allFieldsDenies.push(...denies); + + if (denies.length === 0 && allows.length === 0) { + continue; } + + const guardFunc = this.generateReadFieldGuardFunction(sourceFile, field, allows, denies); + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + result[`readFieldCheck$${field.name}`] = guardFunc.getName()!; + } + + const readFieldCheckSelect = this.generatePreValueSelect(allFieldsAllows, allFieldsDenies); + if (readFieldCheckSelect) { + result[`readFieldSelect`] = readFieldCheckSelect; } - return result; } private generateReadFieldGuardFunction( @@ -371,6 +384,25 @@ export default class PolicyGenerator { return func; } + private generateUpdateFieldsGuards( + model: DataModel, + sourceFile: SourceFile, + result: Record + ) { + for (const field of model.fields) { + const allows = this.getPolicyExpressions(field, 'allow', 'update'); + const denies = this.getPolicyExpressions(field, 'deny', 'update'); + + if (denies.length === 0 && allows.length === 0) { + continue; + } + + const guardFunc = this.generateQueryGuardFunction(sourceFile, model, 'update', allows, denies, field); + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + result[`updateFieldGuard$${field.name}`] = guardFunc.getName()!; + } + } + private canCheckCreateBasedOnInput(model: DataModel, allows: Expression[], denies: Expression[]) { return [...allows, ...denies].every((rule) => { return [...this.allNodes(rule)].every((expr) => { @@ -471,7 +503,7 @@ export default class PolicyGenerator { } } - return Object.keys(result).length === 0 ? null : result; + return Object.keys(result).length === 0 ? undefined : result; } private generateQueryGuardFunction( @@ -479,7 +511,8 @@ export default class PolicyGenerator { model: DataModel, kind: PolicyOperationKind, allows: Expression[], - denies: Expression[] + denies: Expression[], + forField?: DataModelField ): FunctionDeclaration { const statements: (string | WriterFunction | StatementStructures)[] = []; @@ -588,7 +621,7 @@ export default class PolicyGenerator { } const func = sourceFile.addFunction({ - name: model.name + '_' + kind, + name: `${model.name}${forField ? '$' + forField.name : ''}_${kind}`, returnType: 'any', parameters: [ { @@ -596,6 +629,7 @@ export default class PolicyGenerator { type: 'QueryContext', }, { + // for generating field references used by field comparison in the same model name: 'db', type: 'Record', }, diff --git a/packages/schema/src/plugins/access-policy/utils.ts b/packages/schema/src/plugins/access-policy/utils.ts deleted file mode 100644 index 816386a6e..000000000 --- a/packages/schema/src/plugins/access-policy/utils.ts +++ /dev/null @@ -1,10 +0,0 @@ -import { isInvocationExpr } from '@zenstackhq/language/ast'; -import { AstNode } from 'langium/lib/syntax-tree'; -import { isFromStdlib } from '../../language-server/utils'; - -/** - * Returns if the given expression is a "future()" method call. - */ -export function isFutureExpr(node: AstNode) { - return !!(isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)); -} diff --git a/packages/schema/src/plugins/zod/generator.ts b/packages/schema/src/plugins/zod/generator.ts index 9348650c2..678952678 100644 --- a/packages/schema/src/plugins/zod/generator.ts +++ b/packages/schema/src/plugins/zod/generator.ts @@ -10,6 +10,7 @@ import { hasAttribute, isEnumFieldReference, isForeignKeyField, + isFromStdlib, resolvePath, saveProject, } from '@zenstackhq/sdk'; @@ -20,7 +21,6 @@ import { streamAllContents } from 'langium'; import path from 'path'; import { Project } from 'ts-morph'; import { upperCaseFirst } from 'upper-case-first'; -import { isFromStdlib } from '../../language-server/utils'; import { getDefaultOutputFolder } from '../plugin-utils'; import Transformer from './transformer'; import removeDir from './utils/removeDir'; diff --git a/packages/schema/src/res/stdlib.zmodel b/packages/schema/src/res/stdlib.zmodel index a9a892873..84431274c 100644 --- a/packages/schema/src/res/stdlib.zmodel +++ b/packages/schema/src/res/stdlib.zmodel @@ -164,7 +164,7 @@ function isEmpty(field: Any[]): Boolean { /** * Marks an attribute to be only applicable to certain field types. */ -attribute @@@targetField(targetField: AttributeTargetField[]) +attribute @@@targetField(_ targetField: AttributeTargetField[]) /** * Marks an attribute to be used for data validation. diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index c6ea1545f..b2c0771be 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -13,9 +13,9 @@ import { ModelImport, ReferenceExpr, } from '@zenstackhq/language/ast'; +import { isFromStdlib } from '@zenstackhq/sdk'; import { AstNode, getDocument, LangiumDocuments, Mutable } from 'langium'; import { URI, Utils } from 'vscode-uri'; -import { isFromStdlib } from '../language-server/utils'; export function extractDataModelsWithAllowRules(model: Model): DataModel[] { return model.declarations.filter( diff --git a/packages/schema/src/utils/typescript-expression-transformer.ts b/packages/schema/src/utils/typescript-expression-transformer.ts index fb9ac41cf..622b62c06 100644 --- a/packages/schema/src/utils/typescript-expression-transformer.ts +++ b/packages/schema/src/utils/typescript-expression-transformer.ts @@ -12,9 +12,7 @@ import { ThisExpr, UnaryExpr, } from '@zenstackhq/language/ast'; -import { ExpressionContext, getLiteral } from '@zenstackhq/sdk'; -import { isFromStdlib } from '../language-server/utils'; -import { isFutureExpr } from '../plugins/access-policy/utils'; +import { ExpressionContext, getLiteral, isFromStdlib, isFutureExpr } from '@zenstackhq/sdk'; export class TypeScriptExpressionTransformerError extends Error { constructor(message: string) { diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index 448331fd7..889bdc910 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -942,4 +942,58 @@ describe('Attribute tests', () => { `) ).toContain('function "search" is not allowed in the current context: ValidationRule'); }); + + it('invalid policy rule kind', async () => { + expect( + await loadModelWithError(` + ${prelude} + model M { + id String @id + x Int + @@allow('read,foo', x > 0) + } + `) + ).toContain('Invalid policy rule kind: "foo", allowed: "create", "read", "update", "delete", "all"'); + + expect( + await loadModelWithError(` + ${prelude} + model M { + id String @id + x Int + @@deny('update,foo', x > 0) + } + `) + ).toContain('Invalid policy rule kind: "foo", allowed: "create", "read", "update", "delete", "all"'); + + expect( + await loadModelWithError(` + ${prelude} + model M { + id String @id + x Int @allow('foo', x > 0) + } + `) + ).toContain('Invalid policy rule kind: "foo", allowed: "read", "update", "all"'); + + expect( + await loadModelWithError(` + ${prelude} + model M { + id String @id + x Int @deny('foo', x < 0) + } + `) + ).toContain('Invalid policy rule kind: "foo", allowed: "read", "update", "all"'); + + expect( + await loadModelWithError(` + ${prelude} + model M { + id String @id + x Int @allow('update', future().x > 0) + } + `) + ).toContain('"future()" is not allowed in field-level policy rules'); + }); }); diff --git a/packages/sdk/src/constants.ts b/packages/sdk/src/constants.ts index bba0ab93b..08fd6cc18 100644 --- a/packages/sdk/src/constants.ts +++ b/packages/sdk/src/constants.ts @@ -13,3 +13,5 @@ export enum ExpressionContext { AccessPolicy = 'AccessPolicy', ValidationRule = 'ValidationRule', } + +export const STD_LIB_MODULE_NAME = 'stdlib.zmodel'; diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index 9543b3242..fe0e37b4c 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -13,7 +13,9 @@ import { isDataModel, isDataModelField, isEnumField, + isInvocationExpr, isLiteralExpr, + isModel, isObjectExpr, isReferenceExpr, Model, @@ -21,7 +23,7 @@ import { ReferenceExpr, } from '@zenstackhq/language/ast'; import path from 'path'; -import { ExpressionContext } from './constants'; +import { ExpressionContext, STD_LIB_MODULE_NAME } from './constants'; import { PluginOptions } from './types'; /** @@ -280,3 +282,19 @@ export function getFunctionExpressionContext(funcDecl: FunctionDecl) { } return funcAllowedContext; } + +export function isFutureExpr(node: AstNode) { + return !!(isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)); +} + +export function isFromStdlib(node: AstNode) { + const model = getContainingModel(node); + return !!model && !!model.$document && model.$document.uri.path.endsWith(STD_LIB_MODULE_NAME); +} + +export function getContainingModel(node: AstNode | undefined): Model | null { + if (!node) { + return null; + } + return isModel(node) ? node : getContainingModel(node.$container); +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index bed36a56f..ab721f394 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -192,7 +192,7 @@ importers: version: 2.0.3(react@18.2.0) ts-jest: specifier: ^29.0.5 - version: 29.0.5(@babel/core@7.22.5)(esbuild@0.18.13)(jest@29.5.0)(typescript@4.9.4) + version: 29.0.5(@babel/core@7.22.9)(esbuild@0.18.13)(jest@29.5.0)(typescript@4.9.4) typescript: specifier: ^4.9.4 version: 4.9.4 @@ -266,7 +266,7 @@ importers: version: 2.0.3(react@18.2.0) ts-jest: specifier: ^29.0.5 - version: 29.0.5(@babel/core@7.22.5)(esbuild@0.18.13)(jest@29.5.0)(typescript@4.9.4) + version: 29.0.5(@babel/core@7.22.9)(esbuild@0.18.13)(jest@29.5.0)(typescript@4.9.4) typescript: specifier: ^4.9.4 version: 4.9.4 @@ -331,13 +331,13 @@ importers: version: 29.5.0(@types/node@18.0.0) next: specifier: ^13.4.7 - version: 13.4.7(@babel/core@7.22.9)(react-dom@18.2.0)(react@18.2.0) + version: 13.4.7(@babel/core@7.22.5)(react-dom@18.2.0)(react@18.2.0) rimraf: specifier: ^3.0.2 version: 3.0.2 ts-jest: specifier: ^29.0.5 - version: 29.0.5(@babel/core@7.22.9)(esbuild@0.18.13)(jest@29.5.0)(typescript@4.9.4) + version: 29.0.5(@babel/core@7.22.5)(esbuild@0.18.13)(jest@29.5.0)(typescript@4.9.4) typescript: specifier: ^4.9.4 version: 4.9.4 @@ -3453,7 +3453,7 @@ packages: '@trpc/client': 10.32.0(@trpc/server@10.32.0) '@trpc/react-query': 10.32.0(@tanstack/react-query@4.29.7)(@trpc/client@10.32.0)(@trpc/server@10.32.0)(react-dom@18.2.0)(react@18.2.0) '@trpc/server': 10.32.0 - next: 13.4.7(@babel/core@7.22.9)(react-dom@18.2.0)(react@18.2.0) + next: 13.4.7(@babel/core@7.22.5)(react-dom@18.2.0)(react@18.2.0) react: 18.2.0 react-dom: 18.2.0(react@18.2.0) react-ssr-prepass: 1.5.0(react@18.2.0) @@ -8664,7 +8664,7 @@ packages: - babel-plugin-macros dev: true - /next@13.4.7(@babel/core@7.22.9)(react-dom@18.2.0)(react@18.2.0): + /next@13.4.7(@babel/core@7.22.5)(react-dom@18.2.0)(react@18.2.0): resolution: {integrity: sha512-M8z3k9VmG51SRT6v5uDKdJXcAqLzP3C+vaKfLIAM0Mhx1um1G7MDnO63+m52qPdZfrTFzMZNzfsgvm3ghuVHIQ==} engines: {node: '>=16.8.0'} hasBin: true @@ -8689,7 +8689,7 @@ packages: postcss: 8.4.14 react: 18.2.0 react-dom: 18.2.0(react@18.2.0) - styled-jsx: 5.1.1(@babel/core@7.22.9)(react@18.2.0) + styled-jsx: 5.1.1(@babel/core@7.22.5)(react@18.2.0) watchpack: 2.4.0 zod: 3.21.4 optionalDependencies: @@ -10368,6 +10368,24 @@ packages: react: 18.2.0 dev: true + /styled-jsx@5.1.1(@babel/core@7.22.5)(react@18.2.0): + resolution: {integrity: sha512-pW7uC1l4mBZ8ugbiZrcIsiIvVx1UmTfw7UkC3Um2tmfUq9Bhk8IiyEIPl6F8agHgjzku6j0xQEZbfA5uSgSaCw==} + engines: {node: '>= 12.0.0'} + peerDependencies: + '@babel/core': '*' + babel-plugin-macros: '*' + react: '>= 16.8.0 || 17.x.x || ^18.0.0-0' + peerDependenciesMeta: + '@babel/core': + optional: true + babel-plugin-macros: + optional: true + dependencies: + '@babel/core': 7.22.5 + client-only: 0.0.1 + react: 18.2.0 + dev: true + /styled-jsx@5.1.1(@babel/core@7.22.9)(react@18.2.0): resolution: {integrity: sha512-pW7uC1l4mBZ8ugbiZrcIsiIvVx1UmTfw7UkC3Um2tmfUq9Bhk8IiyEIPl6F8agHgjzku6j0xQEZbfA5uSgSaCw==} engines: {node: '>= 12.0.0'} diff --git a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts index cfd5faa82..49c562a3d 100644 --- a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts +++ b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts @@ -12,34 +12,701 @@ describe('With Policy: field-level policy', () => { process.chdir(origDir); }); - it('read', async () => { + it('read simple', async () => { const { prisma, withPolicy } = await loadSchema( ` model User { id Int @id @default(autoincrement()) admin Boolean @default(false) + models Model[] + + @@allow('all', true) + } + + model Model { + id Int @id @default(autoincrement()) + x Int + y Int @allow('read', x > 0) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + + @@allow('all', true) + } + ` + ); + + await prisma.user.create({ data: { id: 1, admin: true } }); + + const db = withPolicy(); + let r; + + // y is unreadable + + r = await db.model.create({ + data: { + id: 1, + x: 0, + y: 0, + ownerId: 1, + }, + }); + expect(r.x).toEqual(0); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ where: { id: 1 } }); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ select: { x: true }, where: { id: 1 } }); + expect(r.x).toEqual(0); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ select: { y: true }, where: { id: 1 } }); + expect(r.x).toBeUndefined(); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ select: { x: false, y: true }, where: { id: 1 } }); + expect(r.x).toBeUndefined(); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ select: { x: true, y: true }, where: { id: 1 } }); + expect(r.x).toEqual(0); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ include: { owner: true }, where: { id: 1 } }); + expect(r.x).toEqual(0); + expect(r.owner).toBeTruthy(); + expect(r.y).toBeUndefined(); + + // y is readable + + r = await db.model.create({ + data: { + id: 2, + x: 1, + y: 0, + ownerId: 1, + }, + }); + expect(r).toEqual(expect.objectContaining({ x: 1, y: 0 })); + + r = await db.model.findUnique({ where: { id: 2 } }); + expect(r).toEqual(expect.objectContaining({ x: 1, y: 0 })); + + r = await db.model.findUnique({ select: { x: true }, where: { id: 2 } }); + expect(r.x).toEqual(1); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ select: { y: true }, where: { id: 2 } }); + expect(r.x).toBeUndefined(); + expect(r.y).toEqual(0); + + r = await db.model.findUnique({ select: { x: false, y: true }, where: { id: 2 } }); + expect(r.x).toBeUndefined(); + expect(r.y).toEqual(0); + + r = await db.model.findUnique({ select: { x: true, y: true }, where: { id: 2 } }); + expect(r).toEqual(expect.objectContaining({ x: 1, y: 0 })); + + r = await db.model.findUnique({ include: { owner: true }, where: { id: 2 } }); + expect(r).toEqual(expect.objectContaining({ x: 1, y: 0 })); + expect(r.owner).toBeTruthy(); + }); + + it('read filter with auth', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + admin Boolean @default(false) + models Model[] + + @@allow('all', true) } model Model { id Int @id @default(autoincrement()) x Int - y Int @allow('read', x > 0 || auth().admin) + y Int @allow('read', auth().admin) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + + @@allow('all', true) + } + `, + { logPrismaQuery: true } + ); + + await prisma.user.create({ data: { id: 1, admin: true } }); + + let db = withPolicy({ id: 1, admin: false }); + let r; + + // y is unreadable + + r = await db.model.create({ + data: { + id: 1, + x: 0, + y: 0, + ownerId: 1, + }, + }); + expect(r.x).toEqual(0); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ where: { id: 1 } }); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ select: { x: true }, where: { id: 1 } }); + expect(r.x).toEqual(0); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ select: { y: true }, where: { id: 1 } }); + expect(r.x).toBeUndefined(); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ select: { x: false, y: true }, where: { id: 1 } }); + expect(r.x).toBeUndefined(); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ select: { x: true, y: true }, where: { id: 1 } }); + expect(r.x).toEqual(0); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ include: { owner: true }, where: { id: 1 } }); + expect(r.x).toEqual(0); + expect(r.owner).toBeTruthy(); + expect(r.y).toBeUndefined(); + + // y is readable + db = withPolicy({ id: 1, admin: true }); + r = await db.model.create({ + data: { + id: 2, + x: 1, + y: 0, + ownerId: 1, + }, + }); + expect(r).toEqual(expect.objectContaining({ x: 1, y: 0 })); + + r = await db.model.findUnique({ where: { id: 2 } }); + expect(r).toEqual(expect.objectContaining({ x: 1, y: 0 })); + + r = await db.model.findUnique({ select: { x: true }, where: { id: 2 } }); + expect(r.x).toEqual(1); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ select: { y: true }, where: { id: 2 } }); + expect(r.x).toBeUndefined(); + expect(r.y).toEqual(0); + + r = await db.model.findUnique({ select: { x: false, y: true }, where: { id: 2 } }); + expect(r.x).toBeUndefined(); + expect(r.y).toEqual(0); + + r = await db.model.findUnique({ select: { x: true, y: true }, where: { id: 2 } }); + expect(r).toEqual(expect.objectContaining({ x: 1, y: 0 })); + + r = await db.model.findUnique({ include: { owner: true }, where: { id: 2 } }); + expect(r).toEqual(expect.objectContaining({ x: 1, y: 0 })); + expect(r.owner).toBeTruthy(); + }); + + it('read filter with relation', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + admin Boolean @default(false) + models Model[] + + @@allow('all', true) + } + + model Model { + id Int @id @default(autoincrement()) + x Int + y Int @allow('read', owner.admin) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int @@allow('all', true) } ` ); - await prisma.model.create({ + await prisma.user.create({ data: { id: 1, admin: false } }); + await prisma.user.create({ data: { id: 2, admin: true } }); + + const db = withPolicy(); + let r; + + // y is unreadable + + r = await db.model.create({ data: { id: 1, x: 0, y: 0, + ownerId: 1, + }, + }); + expect(r.x).toEqual(0); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ where: { id: 1 } }); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ select: { x: true }, where: { id: 1 } }); + expect(r.x).toEqual(0); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ select: { y: true }, where: { id: 1 } }); + expect(r.x).toBeUndefined(); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ select: { x: false, y: true }, where: { id: 1 } }); + expect(r.x).toBeUndefined(); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ select: { x: true, y: true }, where: { id: 1 } }); + expect(r.x).toEqual(0); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ include: { owner: true }, where: { id: 1 } }); + expect(r.x).toEqual(0); + expect(r.owner).toBeTruthy(); + expect(r.y).toBeUndefined(); + + // y is readable + r = await db.model.create({ + data: { + id: 2, + x: 1, + y: 0, + ownerId: 2, }, }); + expect(r).toEqual(expect.objectContaining({ x: 1, y: 0 })); + + r = await db.model.findUnique({ where: { id: 2 } }); + expect(r).toEqual(expect.objectContaining({ x: 1, y: 0 })); + + r = await db.model.findUnique({ select: { x: true }, where: { id: 2 } }); + expect(r.x).toEqual(1); + expect(r.y).toBeUndefined(); + + r = await db.model.findUnique({ select: { y: true }, where: { id: 2 } }); + expect(r.x).toBeUndefined(); + expect(r.y).toEqual(0); + + r = await db.model.findUnique({ select: { x: false, y: true }, where: { id: 2 } }); + expect(r.x).toBeUndefined(); + expect(r.y).toEqual(0); + + r = await db.model.findUnique({ select: { x: true, y: true }, where: { id: 2 } }); + expect(r).toEqual(expect.objectContaining({ x: 1, y: 0 })); + + r = await db.model.findUnique({ include: { owner: true }, where: { id: 2 } }); + expect(r).toEqual(expect.objectContaining({ x: 1, y: 0 })); + expect(r.owner).toBeTruthy(); + }); + + it('read coverage', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + x Int + y Int @allow('read', x > 0) + + @@allow('all', true) + } + ` + ); const db = withPolicy(); - const r = await db.model.findUnique({ where: { id: 1 } }); + let r; + + // y is unreadable + + r = await db.model.create({ + data: { + id: 1, + x: 0, + y: 0, + }, + }); + + r = await db.model.findUnique({ where: { id: 1 } }); + expect(r.y).toBeUndefined(); + + r = await db.model.findUniqueOrThrow({ where: { id: 1 } }); + expect(r.y).toBeUndefined(); + + r = await db.model.findFirst({ where: { id: 1 } }); expect(r.y).toBeUndefined(); + + r = await db.model.findFirstOrThrow({ where: { id: 1 } }); + expect(r.y).toBeUndefined(); + + r = await db.model.findMany({ where: { id: 1 } }); + expect(r[0].y).toBeUndefined(); + }); + + it('update simple', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + models Model[] + + @@allow('all', true) + } + + model Model { + id Int @id @default(autoincrement()) + x Int + y Int @allow('update', x > 0) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + + @@allow('create,read', true) + @@allow('update', y > 0) + } + ` + ); + + await prisma.user.create({ + data: { id: 1 }, + }); + const db = withPolicy(); + + await db.model.create({ + data: { id: 1, x: 0, y: 0, ownerId: 1 }, + }); + await expect( + db.model.update({ + where: { id: 1 }, + data: { y: 2 }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.model.update({ + where: { id: 1 }, + data: { x: 2 }, + }) + ).toBeRejectedByPolicy(); + + await db.model.create({ + data: { id: 2, x: 0, y: 1, ownerId: 1 }, + }); + await expect( + db.model.update({ + where: { id: 2 }, + data: { y: 2 }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.model.update({ + where: { id: 2 }, + data: { x: 2 }, + }) + ).toResolveTruthy(); + + await db.model.create({ + data: { id: 3, x: 1, y: 1, ownerId: 1 }, + }); + await expect( + db.model.update({ + where: { id: 3 }, + data: { y: 2 }, + }) + ).toResolveTruthy(); + }); + + it('update filter with relation', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + models Model[] + admin Boolean @default(false) + + @@allow('all', true) + } + + model Model { + id Int @id @default(autoincrement()) + x Int + y Int @allow('update', owner.admin) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + + @@allow('all', true) + } + ` + ); + + await prisma.user.create({ + data: { id: 1, admin: false }, + }); + await prisma.user.create({ + data: { id: 2, admin: true }, + }); + const db = withPolicy(); + + await db.model.create({ + data: { id: 1, x: 0, y: 0, ownerId: 1 }, + }); + await expect( + db.model.update({ + where: { id: 1 }, + data: { y: 2 }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.model.update({ + where: { id: 1 }, + data: { x: 2 }, + }) + ).toResolveTruthy(); + + await db.model.create({ + data: { id: 2, x: 0, y: 0, ownerId: 2 }, + }); + await expect( + db.model.update({ + where: { id: 2 }, + data: { y: 2 }, + }) + ).toResolveTruthy(); + }); + + it('update to-many relation', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + models Model[] + admin Boolean @default(false) + + @@allow('all', true) + } + + model Model { + id Int @id @default(autoincrement()) + x Int + y Int @allow('update', owner.admin) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + + @@allow('all', true) + } + ` + ); + + await prisma.user.create({ + data: { id: 1, admin: false, models: { create: { id: 1, x: 0, y: 0 } } }, + }); + await prisma.user.create({ + data: { id: 2, admin: true, models: { create: { id: 2, x: 0, y: 0 } } }, + }); + const db = withPolicy(); + + await expect( + db.user.update({ + where: { id: 1 }, + data: { models: { update: { where: { id: 1 }, data: { y: 2 } } } }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.update({ + where: { id: 1 }, + data: { models: { update: { where: { id: 1 }, data: { x: 2 } } } }, + }) + ).toResolveTruthy(); + + await expect( + db.user.update({ + where: { id: 2 }, + data: { models: { update: { where: { id: 2 }, data: { y: 2 } } } }, + }) + ).toResolveTruthy(); + }); + + it('update to-one relation', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + model Model? + admin Boolean @default(false) + + @@allow('all', true) + } + + model Model { + id Int @id @default(autoincrement()) + x Int + y Int @allow('update', owner.admin) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int @unique + + @@allow('all', true) + } + ` + ); + + await prisma.user.create({ + data: { id: 1, admin: false, model: { create: { id: 1, x: 0, y: 0 } } }, + }); + await prisma.user.create({ + data: { id: 2, admin: true, model: { create: { id: 2, x: 0, y: 0 } } }, + }); + const db = withPolicy(); + + await expect( + db.user.update({ + where: { id: 1 }, + data: { model: { update: { data: { y: 2 } } } }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.update({ + where: { id: 1 }, + data: { model: { update: { y: 2 } } }, + }) + ).toBeRejectedByPolicy(); + await expect( + db.user.update({ + where: { id: 1 }, + data: { model: { update: { data: { x: 2 } } } }, + }) + ).toResolveTruthy(); + await expect( + db.user.update({ + where: { id: 1 }, + data: { model: { update: { x: 2 } } }, + }) + ).toResolveTruthy(); + + await expect( + db.user.update({ + where: { id: 2 }, + data: { model: { update: { data: { y: 2 } } } }, + }) + ).toResolveTruthy(); + await expect( + db.user.update({ + where: { id: 2 }, + data: { model: { update: { y: 2 } } }, + }) + ).toResolveTruthy(); + }); + + it('updateMany simple', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + models Model[] + + @@allow('all', true) + } + + model Model { + id Int @id @default(autoincrement()) + x Int + y Int @allow('update', x > 0) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + + @@allow('all', true) + } + ` + ); + + await prisma.user.create({ + data: { + id: 1, + models: { + create: [ + { id: 1, x: 0, y: 0 }, + { id: 2, x: 1, y: 0 }, + ], + }, + }, + }); + const db = withPolicy(); + + await expect(db.model.updateMany({ data: { y: 2 } })).resolves.toEqual({ count: 1 }); + await expect(db.model.findUnique({ where: { id: 1 } })).resolves.toEqual( + expect.objectContaining({ x: 0, y: 0 }) + ); + await expect(db.model.findUnique({ where: { id: 2 } })).resolves.toEqual( + expect.objectContaining({ x: 1, y: 2 }) + ); + }); + + it('updateMany nested', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + models Model[] + + @@allow('all', true) + } + + model Model { + id Int @id @default(autoincrement()) + x Int + y Int @allow('update', x > 0) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + + @@allow('all', true) + } + ` + ); + + await prisma.user.create({ + data: { + id: 1, + models: { + create: [ + { id: 1, x: 0, y: 0 }, + { id: 2, x: 1, y: 0 }, + ], + }, + }, + }); + const db = withPolicy(); + + await expect( + db.user.update({ where: { id: 1 }, data: { models: { updateMany: { data: { y: 2 } } } } }) + ).toResolveTruthy(); + await expect(db.model.findUnique({ where: { id: 1 } })).resolves.toEqual( + expect.objectContaining({ x: 0, y: 0 }) + ); + await expect(db.model.findUnique({ where: { id: 2 } })).resolves.toEqual( + expect.objectContaining({ x: 1, y: 2 }) + ); + + await expect( + db.user.update({ where: { id: 1 }, data: { models: { updateMany: { where: { id: 1 }, data: { y: 2 } } } } }) + ).toResolveTruthy(); + await expect(db.model.findUnique({ where: { id: 1 } })).resolves.toEqual( + expect.objectContaining({ x: 0, y: 0 }) + ); + + await expect( + db.user.update({ where: { id: 1 }, data: { models: { updateMany: { where: { id: 2 }, data: { y: 3 } } } } }) + ).toResolveTruthy(); + await expect(db.model.findUnique({ where: { id: 2 } })).resolves.toEqual( + expect.objectContaining({ x: 1, y: 3 }) + ); }); }); diff --git a/tests/integration/tsconfig.json b/tests/integration/tsconfig.json index babfdfdee..2771cd805 100644 --- a/tests/integration/tsconfig.json +++ b/tests/integration/tsconfig.json @@ -5,7 +5,8 @@ "esModuleInterop": true, "forceConsistentCasingInFileNames": true, "strict": true, - "skipLibCheck": true + "skipLibCheck": true, + "experimentalDecorators": true }, "include": ["**/*.ts", "**/*.d.ts"] } From 566e113a03f10fe124697e38cd5150c1843d0c84 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sat, 19 Aug 2023 09:46:31 +0800 Subject: [PATCH 3/6] wip --- .../src/enhancements/policy/policy-utils.ts | 8 ++++---- packages/runtime/src/enhancements/utils.ts | 20 +++++++++++++++++++ .../with-policy/field-level-policy.test.ts | 11 ++++++++-- 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index bff46ae13..58df66bab 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -11,13 +11,13 @@ import { getFields, resolveField } from '../model-meta'; import { NestedWriteVisitorContext } from '../nested-write-vistor'; import type { InputCheckFunc, ModelMeta, PolicyDef, PolicyFunc, ReadFieldCheckFunc, ZodSchemas } from '../types'; import { - enumerate, formatObject, getIdFields, getModelFields, prismaClientKnownRequestError, prismaClientUnknownRequestError, prismaClientValidationError, + zip, } from '../utils'; import { Logger } from './logger'; @@ -903,7 +903,7 @@ export class PolicyUtil { return; } - for (const entityData of enumerate(data)) { + for (const [entityData, entityFullData] of zip(data, fullData)) { if (typeof entityData !== 'object' || !entityData) { return; } @@ -946,7 +946,7 @@ export class PolicyUtil { } // delete unreadable fields - if (!this.checkReadField(model, field, fullData)) { + if (!this.checkReadField(model, field, entityFullData)) { if (this.shouldLogQuery) { this.logger.info(`[policy] dropping unreadable field ${path ? path + '.' : ''}${field}`); } @@ -959,7 +959,7 @@ export class PolicyUtil { this.doPostProcessForRead( fieldData, fieldInfo.type, - fullData[field], + entityFullData[field], nextArgs, path ? path + '.' + field : field ); diff --git a/packages/runtime/src/enhancements/utils.ts b/packages/runtime/src/enhancements/utils.ts index e2d286da0..c166672b3 100644 --- a/packages/runtime/src/enhancements/utils.ts +++ b/packages/runtime/src/enhancements/utils.ts @@ -51,6 +51,26 @@ export function enumerate(x: Enumerable) { } } +/** + * Zip two arrays or scalars. + */ +export function zip(x: Enumerable, y: Enumerable): Array<[T1, T2]> { + if (Array.isArray(x)) { + if (!Array.isArray(y)) { + throw new Error('x and y should be both array or both scalar'); + } + if (x.length !== y.length) { + throw new Error('x and y should have the same length'); + } + return x.map((_, i) => [x[i], y[i]] as [T1, T2]); + } else { + if (Array.isArray(y)) { + throw new Error('x and y should be both array or both scalar'); + } + return [[x, y]]; + } +} + /** * Formats an object for pretty printing. */ diff --git a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts index 49c562a3d..cc0eda275 100644 --- a/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts +++ b/tests/integration/tests/enhancements/with-policy/field-level-policy.test.ts @@ -330,7 +330,6 @@ describe('With Policy: field-level policy', () => { let r; // y is unreadable - r = await db.model.create({ data: { id: 1, @@ -351,8 +350,16 @@ describe('With Policy: field-level policy', () => { r = await db.model.findFirstOrThrow({ where: { id: 1 } }); expect(r.y).toBeUndefined(); - r = await db.model.findMany({ where: { id: 1 } }); + await db.model.create({ + data: { + id: 2, + x: 1, + y: 0, + }, + }); + r = await db.model.findMany({ where: { x: { gte: 0 } } }); expect(r[0].y).toBeUndefined(); + expect(r[1].y).toEqual(0); }); it('update simple', async () => { From a57d9548cf093e71f9ec15bc1d0f383ccacb2244 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sat, 19 Aug 2023 11:21:22 +0800 Subject: [PATCH 4/6] wip --- packages/runtime/src/constants.ts | 20 +++++ .../src/enhancements/policy/handler.ts | 20 +++-- .../src/enhancements/policy/policy-utils.ts | 88 +++++++++++++------ packages/runtime/src/enhancements/types.ts | 13 ++- .../access-policy/policy-guard-generator.ts | 19 ++-- 5 files changed, 119 insertions(+), 41 deletions(-) diff --git a/packages/runtime/src/constants.ts b/packages/runtime/src/constants.ts index fb3644e60..ad94a367f 100644 --- a/packages/runtime/src/constants.ts +++ b/packages/runtime/src/constants.ts @@ -72,3 +72,23 @@ export const PRISMA_PROXY_ENHANCER = '$__zenstack_enhancer'; * Minimum Prisma version supported */ export const PRISMA_MINIMUM_VERSION = '4.8.0'; + +/** + * Selector function name for fetching pre-update value of entities. + */ +export const PRE_UPDATE_VALUE_SELECTOR = 'preValueSelect'; + +/** + * Prefix for field-level access control guard function name + */ +export const FIELD_LEVEL_POLICY_GUARD_PREFIX = 'readFieldCheck$'; + +/** + * Field-level access control evaluation selector function name + */ +export const FIELD_LEVEL_POLICY_GUARD_SELECTOR = 'readFieldSelect'; + +/** + * Flag that indicates if the model has field-level access control + */ +export const HAS_FIELD_LEVEL_POLICY_FLAG = 'hasFieldLevelPolicy'; diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 78a5d1400..476573a62 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -70,7 +70,9 @@ export class PolicyProxyHandler implements Pr return null; } - this.utils.injectReadCheckSelect(this.model, args); + if (this.utils.hasFieldLevelPolicy(this.model)) { + this.utils.injectReadCheckSelect(this.model, args); + } if (this.shouldLogQuery) { this.logger.info(`[policy] \`findUnique\` ${this.model}:\n${formatObject(args)}`); @@ -94,7 +96,9 @@ export class PolicyProxyHandler implements Pr throw this.utils.notFound(this.model); } - this.utils.injectReadCheckSelect(this.model, args); + if (this.utils.hasFieldLevelPolicy(this.model)) { + this.utils.injectReadCheckSelect(this.model, args); + } if (this.shouldLogQuery) { this.logger.info(`[policy] \`findUniqueOrThrow\` ${this.model}:\n${formatObject(args)}`); @@ -111,7 +115,9 @@ export class PolicyProxyHandler implements Pr return null; } - this.utils.injectReadCheckSelect(this.model, args); + if (this.utils.hasFieldLevelPolicy(this.model)) { + this.utils.injectReadCheckSelect(this.model, args); + } if (this.shouldLogQuery) { this.logger.info(`[policy] \`findFirst\` ${this.model}:\n${formatObject(args)}`); @@ -128,7 +134,9 @@ export class PolicyProxyHandler implements Pr throw this.utils.notFound(this.model); } - this.utils.injectReadCheckSelect(this.model, args); + if (this.utils.hasFieldLevelPolicy(this.model)) { + this.utils.injectReadCheckSelect(this.model, args); + } if (this.shouldLogQuery) { this.logger.info(`[policy] \`findFirstOrThrow\` ${this.model}:\n${formatObject(args)}`); @@ -145,7 +153,9 @@ export class PolicyProxyHandler implements Pr return []; } - this.utils.injectReadCheckSelect(this.model, args); + if (this.utils.hasFieldLevelPolicy(this.model)) { + this.utils.injectReadCheckSelect(this.model, args); + } if (this.shouldLogQuery) { this.logger.info(`[policy] \`findMany\` ${this.model}:\n${formatObject(args)}`); diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 58df66bab..ce6e7a346 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -4,7 +4,15 @@ import deepcopy from 'deepcopy'; import { lowerCaseFirst } from 'lower-case-first'; import { upperCaseFirst } from 'upper-case-first'; import { fromZodError } from 'zod-validation-error'; -import { AUXILIARY_FIELDS, CrudFailureReason, PrismaErrorCode } from '../../constants'; +import { + AUXILIARY_FIELDS, + CrudFailureReason, + FIELD_LEVEL_POLICY_GUARD_PREFIX, + FIELD_LEVEL_POLICY_GUARD_SELECTOR, + HAS_FIELD_LEVEL_POLICY_FLAG, + PRE_UPDATE_VALUE_SELECTOR, + PrismaErrorCode, +} from '../../constants'; import { AuthUser, DbClientContract, DbOperations, FieldInfo, PolicyOperationKind } from '../../types'; import { getVersion } from '../../version'; import { getFields, resolveField } from '../model-meta'; @@ -849,7 +857,7 @@ export class PolicyUtil { if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); } - return guard.preValueSelect; + return guard[PRE_UPDATE_VALUE_SELECTOR]; } getReadFieldSelect(model: string): object | undefined { @@ -857,7 +865,7 @@ export class PolicyUtil { if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); } - return guard.readFieldSelect; + return guard[FIELD_LEVEL_POLICY_GUARD_SELECTOR]; } checkReadField(model: string, field: string, entity: any) { @@ -865,7 +873,7 @@ export class PolicyUtil { if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); } - const func = guard[`readFieldCheck$${field}`] as ReadFieldCheckFunc | undefined; + const func = guard[`${FIELD_LEVEL_POLICY_GUARD_PREFIX}${field}`] as ReadFieldCheckFunc | undefined; if (!func) { return true; } else { @@ -877,6 +885,17 @@ export class PolicyUtil { return this.policy.validation?.[lowerCaseFirst(model)]?.hasValidation === true; } + /** + * Returns if the given model has field-level policy. + */ + hasFieldLevelPolicy(model: string) { + const guard = this.policy.guard[lowerCaseFirst(model)]; + if (!guard) { + throw this.unknownError(`unable to load policy guard for ${model}`); + } + return !!guard[HAS_FIELD_LEVEL_POLICY_FLAG]; + } + /** * Gets Zod schema for the given model and access kind. * @@ -895,10 +914,17 @@ export class PolicyUtil { */ postProcessForRead(data: any, model: string, queryArgs: any) { const origData = this.clone(data); - this.doPostProcessForRead(data, model, origData, queryArgs); + this.doPostProcessForRead(data, model, origData, queryArgs, this.hasFieldLevelPolicy(model)); } - private doPostProcessForRead(data: any, model: string, fullData: any, queryArgs: any, path = '') { + private doPostProcessForRead( + data: any, + model: string, + fullData: any, + queryArgs: any, + hasFieldLevelPolicy: boolean, + path = '' + ) { if (data === null || data === undefined) { return; } @@ -926,34 +952,39 @@ export class PolicyUtil { continue; } - if (!fieldInfo.isDataModel) { - // scalar field, delete unselected ones - const select = queryArgs?.select; - if (select && typeof select === 'object' && select[field] !== true) { - // there's a select clause but this field is not included - delete entityData[field]; - continue; + if (hasFieldLevelPolicy) { + // 1. remove fields selected for checking field-level policies but not selected by the original query args + // 2. evaluate field-level policies and remove fields that are not readable + + if (!fieldInfo.isDataModel) { + // scalar field, delete unselected ones + const select = queryArgs?.select; + if (select && typeof select === 'object' && select[field] !== true) { + // there's a select clause but this field is not included + delete entityData[field]; + continue; + } + } else { + // relation field, delete if not included + const include = queryArgs?.include; + const select = queryArgs?.select; + if (!include?.[field] && !select?.[field]) { + // relation field not included or selected + delete entityData[field]; + continue; + } } - } else { - // relation field, delete if not included - const include = queryArgs?.include; - const select = queryArgs?.select; - if (!include?.[field] && !select?.[field]) { - // relation field not included or selected + + // delete unreadable fields + if (!this.checkReadField(model, field, entityFullData)) { + if (this.shouldLogQuery) { + this.logger.info(`[policy] dropping unreadable field ${path ? path + '.' : ''}${field}`); + } delete entityData[field]; continue; } } - // delete unreadable fields - if (!this.checkReadField(model, field, entityFullData)) { - if (this.shouldLogQuery) { - this.logger.info(`[policy] dropping unreadable field ${path ? path + '.' : ''}${field}`); - } - delete entityData[field]; - continue; - } - if (fieldInfo.isDataModel) { const nextArgs = (queryArgs?.select ?? queryArgs?.include)?.[field]; this.doPostProcessForRead( @@ -961,6 +992,7 @@ export class PolicyUtil { fieldInfo.type, entityFullData[field], nextArgs, + hasFieldLevelPolicy, path ? path + '.' + field : field ); } diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index f227eb5ee..7aa5ed6ee 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -1,5 +1,10 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { z } from 'zod'; +import { + FIELD_LEVEL_POLICY_GUARD_SELECTOR, + HAS_FIELD_LEVEL_POLICY_FLAG, + PRE_UPDATE_VALUE_SELECTOR, +} from '../constants'; import type { DbOperations, FieldInfo, PolicyOperationKind, QueryContext } from '../types'; /** @@ -44,9 +49,11 @@ export type PolicyDef = { } & Partial> & { create_input: InputCheckFunc; } & { - preValueSelect?: object; - readFieldSelect?: object; - } & Record + [PRE_UPDATE_VALUE_SELECTOR]?: object; + [FIELD_LEVEL_POLICY_GUARD_SELECTOR]?: object; + } & Record & { + [HAS_FIELD_LEVEL_POLICY_FLAG]?: boolean; + } >; validation: Record; }; diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index 554e0b759..82f51a4cb 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -17,7 +17,13 @@ import { isThisExpr, isUnaryExpr, } from '@zenstackhq/language/ast'; -import type { PolicyKind, PolicyOperationKind } from '@zenstackhq/runtime'; +import { + FIELD_LEVEL_POLICY_GUARD_PREFIX, + FIELD_LEVEL_POLICY_GUARD_SELECTOR, + HAS_FIELD_LEVEL_POLICY_FLAG, + type PolicyKind, + type PolicyOperationKind, +} from '@zenstackhq/runtime'; import { ExpressionContext, PluginError, @@ -295,12 +301,15 @@ export default class PolicyGenerator { const guardFunc = this.generateReadFieldGuardFunction(sourceFile, field, allows, denies); // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - result[`readFieldCheck$${field.name}`] = guardFunc.getName()!; + result[`${FIELD_LEVEL_POLICY_GUARD_PREFIX}${field.name}`] = guardFunc.getName()!; } - const readFieldCheckSelect = this.generatePreValueSelect(allFieldsAllows, allFieldsDenies); - if (readFieldCheckSelect) { - result[`readFieldSelect`] = readFieldCheckSelect; + if (allFieldsAllows.length > 0 || allFieldsDenies.length > 0) { + result[HAS_FIELD_LEVEL_POLICY_FLAG] = true; + const readFieldCheckSelect = this.generatePreValueSelect(allFieldsAllows, allFieldsDenies); + if (readFieldCheckSelect) { + result[FIELD_LEVEL_POLICY_GUARD_SELECTOR] = readFieldCheckSelect; + } } } From 55fa439ada6fddc09da7890acb3f78d00a34618e Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sat, 19 Aug 2023 19:37:29 +0800 Subject: [PATCH 5/6] code refactor --- packages/runtime/src/constants.ts | 13 +- .../src/enhancements/policy/handler.ts | 20 +- .../src/enhancements/policy/policy-utils.ts | 57 ++++-- packages/runtime/src/enhancements/types.ts | 4 +- .../attribute-application-validator.ts | 4 +- .../function-invocation-validator.ts | 4 +- .../src/language-server/zmodel-linker.ts | 2 +- .../access-policy/policy-guard-generator.ts | 172 ++++++++---------- 8 files changed, 135 insertions(+), 141 deletions(-) diff --git a/packages/runtime/src/constants.ts b/packages/runtime/src/constants.ts index ad94a367f..859184c3c 100644 --- a/packages/runtime/src/constants.ts +++ b/packages/runtime/src/constants.ts @@ -74,19 +74,24 @@ export const PRISMA_PROXY_ENHANCER = '$__zenstack_enhancer'; export const PRISMA_MINIMUM_VERSION = '4.8.0'; /** - * Selector function name for fetching pre-update value of entities. + * Selector function name for fetching pre-update entity values. */ export const PRE_UPDATE_VALUE_SELECTOR = 'preValueSelect'; /** - * Prefix for field-level access control guard function name + * Prefix for field-level read checker function name */ -export const FIELD_LEVEL_POLICY_GUARD_PREFIX = 'readFieldCheck$'; +export const FIELD_LEVEL_READ_CHECKER_PREFIX = 'readFieldCheck$'; /** * Field-level access control evaluation selector function name */ -export const FIELD_LEVEL_POLICY_GUARD_SELECTOR = 'readFieldSelect'; +export const FIELD_LEVEL_READ_CHECKER_SELECTOR = 'readFieldSelect'; + +/** + * Prefix for field-level update guard function name + */ +export const FIELD_LEVEL_UPDATE_GUARD_PREFIX = 'updateFieldCheck$'; /** * Flag that indicates if the model has field-level access control diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 476573a62..78a5d1400 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -70,9 +70,7 @@ export class PolicyProxyHandler implements Pr return null; } - if (this.utils.hasFieldLevelPolicy(this.model)) { - this.utils.injectReadCheckSelect(this.model, args); - } + this.utils.injectReadCheckSelect(this.model, args); if (this.shouldLogQuery) { this.logger.info(`[policy] \`findUnique\` ${this.model}:\n${formatObject(args)}`); @@ -96,9 +94,7 @@ export class PolicyProxyHandler implements Pr throw this.utils.notFound(this.model); } - if (this.utils.hasFieldLevelPolicy(this.model)) { - this.utils.injectReadCheckSelect(this.model, args); - } + this.utils.injectReadCheckSelect(this.model, args); if (this.shouldLogQuery) { this.logger.info(`[policy] \`findUniqueOrThrow\` ${this.model}:\n${formatObject(args)}`); @@ -115,9 +111,7 @@ export class PolicyProxyHandler implements Pr return null; } - if (this.utils.hasFieldLevelPolicy(this.model)) { - this.utils.injectReadCheckSelect(this.model, args); - } + this.utils.injectReadCheckSelect(this.model, args); if (this.shouldLogQuery) { this.logger.info(`[policy] \`findFirst\` ${this.model}:\n${formatObject(args)}`); @@ -134,9 +128,7 @@ export class PolicyProxyHandler implements Pr throw this.utils.notFound(this.model); } - if (this.utils.hasFieldLevelPolicy(this.model)) { - this.utils.injectReadCheckSelect(this.model, args); - } + this.utils.injectReadCheckSelect(this.model, args); if (this.shouldLogQuery) { this.logger.info(`[policy] \`findFirstOrThrow\` ${this.model}:\n${formatObject(args)}`); @@ -153,9 +145,7 @@ export class PolicyProxyHandler implements Pr return []; } - if (this.utils.hasFieldLevelPolicy(this.model)) { - this.utils.injectReadCheckSelect(this.model, args); - } + this.utils.injectReadCheckSelect(this.model, args); if (this.shouldLogQuery) { this.logger.info(`[policy] \`findMany\` ${this.model}:\n${formatObject(args)}`); diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index ce6e7a346..142ab154a 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -7,8 +7,9 @@ import { fromZodError } from 'zod-validation-error'; import { AUXILIARY_FIELDS, CrudFailureReason, - FIELD_LEVEL_POLICY_GUARD_PREFIX, - FIELD_LEVEL_POLICY_GUARD_SELECTOR, + FIELD_LEVEL_READ_CHECKER_PREFIX, + FIELD_LEVEL_READ_CHECKER_SELECTOR, + FIELD_LEVEL_UPDATE_GUARD_PREFIX, HAS_FIELD_LEVEL_POLICY_FLAG, PRE_UPDATE_VALUE_SELECTOR, PrismaErrorCode, @@ -193,13 +194,16 @@ export class PolicyUtil { return this.reduce(r); } + /** + * Get field-level auth guard + */ getFieldUpdateAuthGuard(db: Record, model: string, field: string): object { const guard = this.policy.guard[lowerCaseFirst(model)]; if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); } - const provider = guard[`updateFieldGuard$${field}`]; + const provider = guard[`${FIELD_LEVEL_UPDATE_GUARD_PREFIX}${field}`]; if (typeof provider === 'boolean') { return this.reduce(provider); } @@ -261,9 +265,11 @@ export class PolicyUtil { // merge field-level policy guards const fieldUpdateGuard = this.getFieldUpdateGuards(db, model, args); if (fieldUpdateGuard.rejectedByField) { + // rejected args.where = this.makeFalse(); return false; } else if (fieldUpdateGuard.guard) { + // merge guard = this.and(guard, fieldUpdateGuard.guard); } } @@ -594,6 +600,7 @@ export class PolicyUtil { // merge field-level policy guards const fieldUpdateGuard = this.getFieldUpdateGuards(db, model, args); if (fieldUpdateGuard.rejectedByField) { + // rejected throw this.deniedByPolicy( model, 'update', @@ -602,6 +609,7 @@ export class PolicyUtil { }"` ); } else if (fieldUpdateGuard.guard) { + // merge guard = this.and(guard, fieldUpdateGuard.guard); } } @@ -655,7 +663,7 @@ export class PolicyUtil { } private getFieldUpdateGuards(db: Record, model: string, args: any) { - let allFieldGuards; + const allFieldGuards = []; for (const [k, v] of Object.entries(args.data ?? args)) { if (typeof v === 'undefined') { continue; @@ -664,9 +672,9 @@ export class PolicyUtil { if (this.isFalse(fieldGuard)) { return { guard: allFieldGuards, rejectedByField: k }; } - allFieldGuards = this.and(allFieldGuards, fieldGuard); + allFieldGuards.push(fieldGuard); } - return { guard: allFieldGuards, rejectedByField: undefined }; + return { guard: this.and(...allFieldGuards), rejectedByField: undefined }; } /** @@ -745,21 +753,30 @@ export class PolicyUtil { return { result, error: undefined }; } + /** + * Injects field selection needed for checking field-level read policy into query args. + * @returns + */ injectReadCheckSelect(model: string, args: any) { + if (!this.hasFieldLevelPolicy(model)) { + return; + } + const readFieldSelect = this.getReadFieldSelect(model); if (!readFieldSelect) { return; } + this.doInjectReadCheckSelect(model, args, { select: readFieldSelect }); } private doInjectReadCheckSelect(model: string, args: any, input: any) { - if (!input.select) { + if (!input?.select) { return; } - let target: any; - let isInclude = false; + let target: any; // injection target + let isInclude = false; // if the target is include or select if (args.select) { target = args.select; @@ -783,10 +800,11 @@ export class PolicyUtil { } } + // recurse into nested selects (relation fields) for (const [k, v] of Object.entries(input.select)) { if (typeof v === 'object' && v?.select) { const field = resolveField(this.modelMeta, model, k); - if (field && field.isDataModel) { + if (field?.isDataModel) { // recurse into relation if (isInclude && target[k] === true) { // select all fields for the relation @@ -795,6 +813,7 @@ export class PolicyUtil { // ensure an empty select clause target[k] = { select: {} }; } + // recurse this.doInjectReadCheckSelect(field.type, target[k], v); } } @@ -860,20 +879,20 @@ export class PolicyUtil { return guard[PRE_UPDATE_VALUE_SELECTOR]; } - getReadFieldSelect(model: string): object | undefined { + private getReadFieldSelect(model: string): object | undefined { const guard = this.policy.guard[lowerCaseFirst(model)]; if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); } - return guard[FIELD_LEVEL_POLICY_GUARD_SELECTOR]; + return guard[FIELD_LEVEL_READ_CHECKER_SELECTOR]; } - checkReadField(model: string, field: string, entity: any) { + private checkReadField(model: string, field: string, entity: any) { const guard = this.policy.guard[lowerCaseFirst(model)]; if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); } - const func = guard[`${FIELD_LEVEL_POLICY_GUARD_PREFIX}${field}`] as ReadFieldCheckFunc | undefined; + const func = guard[`${FIELD_LEVEL_READ_CHECKER_PREFIX}${field}`] as ReadFieldCheckFunc | undefined; if (!func) { return true; } else { @@ -885,10 +904,7 @@ export class PolicyUtil { return this.policy.validation?.[lowerCaseFirst(model)]?.hasValidation === true; } - /** - * Returns if the given model has field-level policy. - */ - hasFieldLevelPolicy(model: string) { + private hasFieldLevelPolicy(model: string) { const guard = this.policy.guard[lowerCaseFirst(model)]; if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); @@ -913,6 +929,8 @@ export class PolicyUtil { * Post processing checks and clean-up for read model entities. */ postProcessForRead(data: any, model: string, queryArgs: any) { + // preserve the original data as it may be needed for checking field-level readability, + // while the "data" will be manipulated during traversal (deleting unreadable fields) const origData = this.clone(data); this.doPostProcessForRead(data, model, origData, queryArgs, this.hasFieldLevelPolicy(model)); } @@ -965,7 +983,7 @@ export class PolicyUtil { continue; } } else { - // relation field, delete if not included + // relation field, delete if not selected or included const include = queryArgs?.include; const select = queryArgs?.select; if (!include?.[field] && !select?.[field]) { @@ -986,6 +1004,7 @@ export class PolicyUtil { } if (fieldInfo.isDataModel) { + // recurse into nested fields const nextArgs = (queryArgs?.select ?? queryArgs?.include)?.[field]; this.doPostProcessForRead( fieldData, diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index 7aa5ed6ee..d879bf510 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -1,7 +1,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { z } from 'zod'; import { - FIELD_LEVEL_POLICY_GUARD_SELECTOR, + FIELD_LEVEL_READ_CHECKER_SELECTOR, HAS_FIELD_LEVEL_POLICY_FLAG, PRE_UPDATE_VALUE_SELECTOR, } from '../constants'; @@ -50,7 +50,7 @@ export type PolicyDef = { create_input: InputCheckFunc; } & { [PRE_UPDATE_VALUE_SELECTOR]?: object; - [FIELD_LEVEL_POLICY_GUARD_SELECTOR]?: object; + [FIELD_LEVEL_READ_CHECKER_SELECTOR]?: object; } & Record & { [HAS_FIELD_LEVEL_POLICY_FLAG]?: boolean; } diff --git a/packages/schema/src/language-server/validator/attribute-application-validator.ts b/packages/schema/src/language-server/validator/attribute-application-validator.ts index 76024c236..a820620d1 100644 --- a/packages/schema/src/language-server/validator/attribute-application-validator.ts +++ b/packages/schema/src/language-server/validator/attribute-application-validator.ts @@ -15,10 +15,10 @@ import { isEnum, isReferenceExpr, } from '@zenstackhq/language/ast'; +import { isFutureExpr, resolved } from '@zenstackhq/sdk'; import { ValidationAcceptor, streamAllContents } from 'langium'; -import { AstValidator } from '../types'; import pluralize from 'pluralize'; -import { isFutureExpr, resolved } from '@zenstackhq/sdk'; +import { AstValidator } from '../types'; import { getStringLiteral, mapBuiltinTypeToExpressionType, typeAssignable } from './utils'; // a registry of function handlers marked with @func diff --git a/packages/schema/src/language-server/validator/function-invocation-validator.ts b/packages/schema/src/language-server/validator/function-invocation-validator.ts index 3ccf5a260..5e0f1d639 100644 --- a/packages/schema/src/language-server/validator/function-invocation-validator.ts +++ b/packages/schema/src/language-server/validator/function-invocation-validator.ts @@ -11,12 +11,12 @@ import { isDataModelFieldAttribute, isLiteralExpr, } from '@zenstackhq/language/ast'; +import { ExpressionContext, getFunctionExpressionContext, isEnumFieldReference, isFromStdlib } from '@zenstackhq/sdk'; import { AstNode, ValidationAcceptor } from 'langium'; +import { P, match } from 'ts-pattern'; import { getDataModelFieldReference } from '../../utils/ast-utils'; import { AstValidator } from '../types'; import { typeAssignable } from './utils'; -import { match, P } from 'ts-pattern'; -import { ExpressionContext, getFunctionExpressionContext, isEnumFieldReference, isFromStdlib } from '@zenstackhq/sdk'; /** * InvocationExpr validation diff --git a/packages/schema/src/language-server/zmodel-linker.ts b/packages/schema/src/language-server/zmodel-linker.ts index 7b9d42956..da8b3bef8 100644 --- a/packages/schema/src/language-server/zmodel-linker.ts +++ b/packages/schema/src/language-server/zmodel-linker.ts @@ -29,6 +29,7 @@ import { ThisExpr, UnaryExpr, } from '@zenstackhq/language/ast'; +import { getContainingModel, isFromStdlib } from '@zenstackhq/sdk'; import { AstNode, AstNodeDescription, @@ -46,7 +47,6 @@ import { import { CancellationToken } from 'vscode-jsonrpc'; import { getAllDeclarationsFromImports } from '../utils/ast-utils'; import { mapBuiltinTypeToExpressionType } from './validator/utils'; -import { getContainingModel, isFromStdlib } from '@zenstackhq/sdk'; interface DefaultReference extends Reference { _ref?: AstNode | LinkingError; diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index 82f51a4cb..85eedf96e 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -18,8 +18,9 @@ import { isUnaryExpr, } from '@zenstackhq/language/ast'; import { - FIELD_LEVEL_POLICY_GUARD_PREFIX, - FIELD_LEVEL_POLICY_GUARD_SELECTOR, + FIELD_LEVEL_READ_CHECKER_PREFIX, + FIELD_LEVEL_READ_CHECKER_SELECTOR, + FIELD_LEVEL_UPDATE_GUARD_PREFIX, HAS_FIELD_LEVEL_POLICY_FLAG, type PolicyKind, type PolicyOperationKind, @@ -259,7 +260,7 @@ export default class PolicyGenerator { result[kind] = guardFunc.getName()!; if (kind === 'postUpdate') { - const preValueSelect = this.generatePreValueSelect(allows, denies); + const preValueSelect = this.generateSelectForRules(allows, denies); if (preValueSelect) { result['preValueSelect'] = preValueSelect; } @@ -292,23 +293,23 @@ export default class PolicyGenerator { for (const field of model.fields) { const allows = this.getPolicyExpressions(field, 'allow', 'read'); const denies = this.getPolicyExpressions(field, 'deny', 'read'); - allFieldsAllows.push(...allows); - allFieldsDenies.push(...denies); - if (denies.length === 0 && allows.length === 0) { continue; } + allFieldsAllows.push(...allows); + allFieldsDenies.push(...denies); + const guardFunc = this.generateReadFieldGuardFunction(sourceFile, field, allows, denies); // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - result[`${FIELD_LEVEL_POLICY_GUARD_PREFIX}${field.name}`] = guardFunc.getName()!; + result[`${FIELD_LEVEL_READ_CHECKER_PREFIX}${field.name}`] = guardFunc.getName()!; } if (allFieldsAllows.length > 0 || allFieldsDenies.length > 0) { result[HAS_FIELD_LEVEL_POLICY_FLAG] = true; - const readFieldCheckSelect = this.generatePreValueSelect(allFieldsAllows, allFieldsDenies); + const readFieldCheckSelect = this.generateSelectForRules(allFieldsAllows, allFieldsDenies); if (readFieldCheckSelect) { - result[FIELD_LEVEL_POLICY_GUARD_SELECTOR] = readFieldCheckSelect; + result[FIELD_LEVEL_READ_CHECKER_SELECTOR] = readFieldCheckSelect; } } } @@ -319,41 +320,18 @@ export default class PolicyGenerator { allows: Expression[], denies: Expression[] ) { - const statements: (string | WriterFunction | StatementStructures)[] = []; - - // check if any allow or deny rule contains 'auth()' invocation - const hasAuthRef = [...denies, ...allows].some((rule) => - [...this.allNodes(rule)].some((child) => isAuthInvocation(child)) - ); - - if (hasAuthRef) { - const root = findRootNode(field) as Model; - const userModel = root.declarations.find( - (decl): decl is DataModel => isDataModel(decl) && decl.name === 'User' - ); - if (!userModel) { - throw new PluginError(name, 'User model not found'); - } - const userIdFields = getIdFields(userModel); - if (!userIdFields || userIdFields.length === 0) { - throw new PluginError(name, 'User model does not have an id field'); - } + const statements: (string | WriterFunction)[] = []; - // normalize user to null to avoid accidentally use undefined in filter - statements.push( - `const user = hasAllFields(context.user, [${userIdFields - .map((f) => "'" + f.name + "'") - .join(', ')}]) ? context.user as any : null;` - ); - } + this.generateNormalizedAuthRef(field.$container as DataModel, allows, denies, statements); + // compile rules down to typescript expressions statements.push((writer) => { const transformer = new TypeScriptExpressionTransformer({ context: ExpressionContext.AccessPolicy, fieldReferenceContext: 'input', }); - let expr = + const denyStmt = denies.length > 0 ? '!(' + denies @@ -364,13 +342,29 @@ export default class PolicyGenerator { ')' : undefined; - const allowStmt = allows - .map((allow) => { - return transformer.transform(allow); - }) - .join(' || '); + const allowStmt = + allows.length > 0 + ? '(' + + allows + .map((allow) => { + return transformer.transform(allow); + }) + .join(' || ') + + ')' + : undefined; + + let expr: string | undefined; + + if (denyStmt && allowStmt) { + expr = `${denyStmt} && ${allowStmt}`; + } else if (denyStmt) { + expr = denyStmt; + } else if (allowStmt) { + expr = allowStmt; + } else { + throw new Error('should not happen'); + } - expr = expr ? `${expr} && (${allowStmt})` : allowStmt; writer.write('return ' + expr); }); @@ -408,7 +402,7 @@ export default class PolicyGenerator { const guardFunc = this.generateQueryGuardFunction(sourceFile, model, 'update', allows, denies, field); // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - result[`updateFieldGuard$${field.name}`] = guardFunc.getName()!; + result[`${FIELD_LEVEL_UPDATE_GUARD_PREFIX}${field.name}`] = guardFunc.getName()!; } } @@ -448,9 +442,9 @@ export default class PolicyGenerator { }); } - // generates an object that can be used as the 'select' argument when fetching pre-update - // entity value - private generatePreValueSelect(allows: Expression[], denies: Expression[]): object { + // generates a "select" object that contains (recursively) fields referenced by the + // given policy rules + private generateSelectForRules(allows: Expression[], denies: Expression[]): object { // eslint-disable-next-line @typescript-eslint/no-explicit-any const result: any = {}; const addPath = (path: string[]) => { @@ -523,32 +517,9 @@ export default class PolicyGenerator { denies: Expression[], forField?: DataModelField ): FunctionDeclaration { - const statements: (string | WriterFunction | StatementStructures)[] = []; + const statements: (string | WriterFunction)[] = []; - // check if any allow or deny rule contains 'auth()' invocation - const hasAuthRef = [...denies, ...allows].some((rule) => - [...this.allNodes(rule)].some((child) => isAuthInvocation(child)) - ); - - if (hasAuthRef) { - const userModel = model.$container.declarations.find( - (decl): decl is DataModel => isDataModel(decl) && decl.name === 'User' - ); - if (!userModel) { - throw new PluginError(name, 'User model not found'); - } - const userIdFields = getIdFields(userModel); - if (!userIdFields || userIdFields.length === 0) { - throw new PluginError(name, 'User model does not have an id field'); - } - - // normalize user to null to avoid accidentally use undefined in filter - statements.push( - `const user = hasAllFields(context.user, [${userIdFields - .map((f) => "'" + f.name + "'") - .join(', ')}]) ? context.user as any : null;` - ); - } + this.generateNormalizedAuthRef(model, allows, denies, statements); const hasFieldAccess = [...denies, ...allows].some((rule) => [...this.allNodes(rule)].some( @@ -656,32 +627,9 @@ export default class PolicyGenerator { allows: Expression[], denies: Expression[] ): FunctionDeclaration { - const statements: (string | WriterFunction | StatementStructures)[] = []; + const statements: (string | WriterFunction)[] = []; - // check if any allow or deny rule contains 'auth()' invocation - const hasAuthRef = [...denies, ...allows].some((rule) => - [...this.allNodes(rule)].some((child) => isAuthInvocation(child)) - ); - - if (hasAuthRef) { - const userModel = model.$container.declarations.find( - (decl): decl is DataModel => isDataModel(decl) && decl.name === 'User' - ); - if (!userModel) { - throw new PluginError(name, 'User model not found'); - } - const userIdFields = getIdFields(userModel); - if (!userIdFields || userIdFields.length === 0) { - throw new PluginError(name, 'User model does not have an id field'); - } - - // normalize user to null to avoid accidentally use undefined in filter - statements.push( - `const user = hasAllFields(context.user, [${userIdFields - .map((f) => "'" + f.name + "'") - .join(', ')}]) ? context.user as any : null;` - ); - } + this.generateNormalizedAuthRef(model, allows, denies, statements); statements.push((writer) => { if (allows.length === 0) { @@ -734,6 +682,38 @@ export default class PolicyGenerator { return func; } + private generateNormalizedAuthRef( + model: DataModel, + allows: Expression[], + denies: Expression[], + statements: (string | WriterFunction)[] + ) { + // check if any allow or deny rule contains 'auth()' invocation + const hasAuthRef = [...allows, ...denies].some((rule) => + [...this.allNodes(rule)].some((child) => isAuthInvocation(child)) + ); + + if (hasAuthRef) { + const userModel = model.$container.declarations.find( + (decl): decl is DataModel => isDataModel(decl) && decl.name === 'User' + ); + if (!userModel) { + throw new PluginError(name, 'User model not found'); + } + const userIdFields = getIdFields(userModel); + if (!userIdFields || userIdFields.length === 0) { + throw new PluginError(name, 'User model does not have an id field'); + } + + // normalize user to null to avoid accidentally use undefined in filter + statements.push( + `const user = hasAllFields(context.user, [${userIdFields + .map((f) => "'" + f.name + "'") + .join(', ')}]) ? context.user as any : null;` + ); + } + } + private *allNodes(expr: Expression) { yield expr; yield* streamAllContents(expr); From 466bb3442af6271fff45bfe45b41e5a153505eb7 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Sat, 19 Aug 2023 20:45:55 +0800 Subject: [PATCH 6/6] fix build --- .../plugins/access-policy/policy-guard-generator.ts | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index 85eedf96e..28ac12805 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -45,16 +45,10 @@ import { resolved, saveProject, } from '@zenstackhq/sdk'; -import { findRootNode, streamAllContents } from 'langium'; +import { streamAllContents } from 'langium'; import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; -import { - FunctionDeclaration, - SourceFile, - StatementStructures, - VariableDeclarationKind, - WriterFunction, -} from 'ts-morph'; +import { FunctionDeclaration, SourceFile, VariableDeclarationKind, WriterFunction } from 'ts-morph'; import { name } from '.'; import { getIdFields, isAuthInvocation } from '../../utils/ast-utils'; import {