diff --git a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts index 6827584d1..62e570e6d 100644 --- a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts +++ b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/Post.router.ts @@ -23,6 +23,10 @@ export default function createRouter( .input($Schema.PostInputSchema.aggregate) .query(({ ctx, input }) => checkRead(db(ctx).post.aggregate(input as any))), + createMany: procedure + .input($Schema.PostInputSchema.createMany) + .mutation(async ({ ctx, input }) => checkMutate(db(ctx).post.createMany(input as any))), + create: procedure .input($Schema.PostInputSchema.create) .mutation(async ({ ctx, input }) => checkMutate(db(ctx).post.create(input as any))), @@ -88,6 +92,29 @@ export interface ClientType, Error>, ) => UseTRPCInfiniteQueryResult, TRPCClientErrorLike>; }; + createMany: { + useMutation: ( + opts?: UseTRPCMutationOptions< + Prisma.PostCreateManyArgs, + TRPCClientErrorLike, + Prisma.BatchPayload, + Context + >, + ) => Omit< + UseTRPCMutationResult< + Prisma.BatchPayload, + TRPCClientErrorLike, + Prisma.SelectSubset, + Context + >, + 'mutateAsync' + > & { + mutateAsync: ( + variables: T, + opts?: UseTRPCMutationOptions, Prisma.BatchPayload, Context>, + ) => Promise; + }; + }; create: { useMutation: ( opts?: UseTRPCMutationOptions< diff --git a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts index 06ce01f31..4c686b057 100644 --- a/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts +++ b/packages/plugins/trpc/tests/projects/t3-trpc-v10/src/server/api/routers/generated/routers/User.router.ts @@ -23,6 +23,10 @@ export default function createRouter( .input($Schema.UserInputSchema.aggregate) .query(({ ctx, input }) => checkRead(db(ctx).user.aggregate(input as any))), + createMany: procedure + .input($Schema.UserInputSchema.createMany) + .mutation(async ({ ctx, input }) => checkMutate(db(ctx).user.createMany(input as any))), + create: procedure .input($Schema.UserInputSchema.create) .mutation(async ({ ctx, input }) => checkMutate(db(ctx).user.create(input as any))), @@ -88,6 +92,29 @@ export interface ClientType, Error>, ) => UseTRPCInfiniteQueryResult, TRPCClientErrorLike>; }; + createMany: { + useMutation: ( + opts?: UseTRPCMutationOptions< + Prisma.UserCreateManyArgs, + TRPCClientErrorLike, + Prisma.BatchPayload, + Context + >, + ) => Omit< + UseTRPCMutationResult< + Prisma.BatchPayload, + TRPCClientErrorLike, + Prisma.SelectSubset, + Context + >, + 'mutateAsync' + > & { + mutateAsync: ( + variables: T, + opts?: UseTRPCMutationOptions, Prisma.BatchPayload, Context>, + ) => Promise; + }; + }; create: { useMutation: ( opts?: UseTRPCMutationOptions< diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index ef48f7f38..6b7e67bea 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -690,16 +690,25 @@ export class PolicyProxyHandler implements Pr const postWriteChecks: PostWriteCheckRecord[] = []; // registers a post-update check task - const _registerPostUpdateCheck = async (model: string, uniqueFilter: any) => { + const _registerPostUpdateCheck = async ( + model: string, + preUpdateLookupFilter: any, + postUpdateLookupFilter: any + ) => { // both "post-update" rules and Zod schemas require a post-update check if (this.utils.hasAuthGuard(model, 'postUpdate') || this.utils.getZodSchema(model)) { // select pre-update field values let preValue: any; const preValueSelect = this.utils.getPreValueSelect(model); if (preValueSelect && Object.keys(preValueSelect).length > 0) { - preValue = await db[model].findFirst({ where: uniqueFilter, select: preValueSelect }); + preValue = await db[model].findFirst({ where: preUpdateLookupFilter, select: preValueSelect }); } - postWriteChecks.push({ model, operation: 'postUpdate', uniqueFilter, preValue }); + postWriteChecks.push({ + model, + operation: 'postUpdate', + uniqueFilter: postUpdateLookupFilter, + preValue, + }); } }; @@ -826,7 +835,7 @@ export class PolicyProxyHandler implements Pr await this.utils.checkPolicyForUnique(model, args, 'update', db, checkArgs); // register post-update check - await _registerPostUpdateCheck(model, args); + await _registerPostUpdateCheck(model, args, args); } } }; @@ -873,7 +882,7 @@ export class PolicyProxyHandler implements Pr await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args); // handles the case where id fields are updated - const ids = this.utils.clone(existing); + const postUpdateIds = this.utils.clone(existing); for (const key of Object.keys(existing)) { const updateValue = (args as any).data ? (args as any).data[key] : (args as any)[key]; if ( @@ -881,12 +890,12 @@ export class PolicyProxyHandler implements Pr typeof updateValue === 'number' || typeof updateValue === 'bigint' ) { - ids[key] = updateValue; + postUpdateIds[key] = updateValue; } } // register post-update check - await _registerPostUpdateCheck(model, ids); + await _registerPostUpdateCheck(model, existing, postUpdateIds); } }, @@ -978,7 +987,7 @@ export class PolicyProxyHandler implements Pr await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args); // register post-update check - await _registerPostUpdateCheck(model, uniqueFilter); + await _registerPostUpdateCheck(model, uniqueFilter, uniqueFilter); // convert upsert to update const convertedUpdate = { diff --git a/packages/schema/src/plugins/access-policy/expression-writer.ts b/packages/schema/src/plugins/access-policy/expression-writer.ts index 2ab3e2bdd..a5de026f0 100644 --- a/packages/schema/src/plugins/access-policy/expression-writer.ts +++ b/packages/schema/src/plugins/access-policy/expression-writer.ts @@ -70,6 +70,8 @@ export class ExpressionWriter { this.plainExprBuilder = new TypeScriptExpressionTransformer({ context: ExpressionContext.AccessPolicy, isPostGuard: this.isPostGuard, + // in post-guard context, `this` references pre-update value + thisExprContext: this.isPostGuard ? 'context.preValue' : undefined, }); } diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index 2025c3d5c..20893da10 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -6,7 +6,6 @@ import { Enum, Expression, Model, - isBinaryExpr, isDataModel, isDataModelField, isEnum, @@ -15,7 +14,6 @@ import { isMemberAccessExpr, isReferenceExpr, isThisExpr, - isUnaryExpr, } from '@zenstackhq/language/ast'; import { FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX, @@ -281,30 +279,6 @@ export default class PolicyGenerator { } } - private visitPolicyExpression(expr: Expression, postUpdate: boolean): Expression | undefined { - if (isBinaryExpr(expr) && (expr.operator === '&&' || expr.operator === '||')) { - const left = this.visitPolicyExpression(expr.left, postUpdate); - const right = this.visitPolicyExpression(expr.right, postUpdate); - if (!left) return right; - if (!right) return left; - return { ...expr, left, right }; - } - - if (isUnaryExpr(expr) && expr.operator === '!') { - const operand = this.visitPolicyExpression(expr.operand, postUpdate); - if (!operand) return undefined; - return { ...expr, operand }; - } - - if (postUpdate && !this.hasFutureReference(expr)) { - return undefined; - } else if (!postUpdate && this.hasFutureReference(expr)) { - return undefined; - } - - return expr; - } - private hasFutureReference(expr: Expression) { for (const node of streamAst(expr)) { if (isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)) { @@ -599,13 +573,19 @@ export default class PolicyGenerator { // visit a reference or member access expression to build a // selection path const visit = (node: Expression): string[] | undefined => { + if (isThisExpr(node)) { + return []; + } + if (isReferenceExpr(node)) { const target = resolved(node.target); if (isDataModelField(target)) { // a field selection, it's a terminal return [target.name]; } - } else if (isMemberAccessExpr(node)) { + } + + if (isMemberAccessExpr(node)) { if (forAuthContext && isAuthInvocation(node.operand)) { return [node.member.$refText]; } @@ -621,6 +601,7 @@ export default class PolicyGenerator { return [...inner, node.member.$refText]; } } + return undefined; }; diff --git a/packages/schema/src/utils/typescript-expression-transformer.ts b/packages/schema/src/utils/typescript-expression-transformer.ts index ec4f89fcb..27e018aa1 100644 --- a/packages/schema/src/utils/typescript-expression-transformer.ts +++ b/packages/schema/src/utils/typescript-expression-transformer.ts @@ -112,9 +112,7 @@ export class TypeScriptExpressionTransformer { throw new TypeScriptExpressionTransformerError(`Unresolved MemberAccessExpr`); } - if (isThisExpr(expr.operand)) { - return expr.member.ref.name; - } else if (isFutureExpr(expr.operand)) { + if (isFutureExpr(expr.operand)) { if (this.options?.isPostGuard !== true) { throw new TypeScriptExpressionTransformerError(`future() is only supported in postUpdate rules`); } diff --git a/tests/integration/tests/regression/issue-1235.test.ts b/tests/integration/tests/regression/issue-1235.test.ts new file mode 100644 index 000000000..1e9f80f86 --- /dev/null +++ b/tests/integration/tests/regression/issue-1235.test.ts @@ -0,0 +1,35 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('issue 1235', () => { + it('regression1', async () => { + const { enhance } = await loadSchema( + ` + model Post { + id Int @id @default(autoincrement()) + @@deny("update", future().id != id) + @@allow("all", true) + } + ` + ); + + const db = enhance(); + const post = await db.post.create({ data: {} }); + await expect(db.post.update({ data: { id: post.id + 1 }, where: { id: post.id } })).toBeRejectedByPolicy(); + }); + + it('regression2', async () => { + const { enhance } = await loadSchema( + ` + model Post { + id Int @id @default(autoincrement()) + @@deny("update", future().id != this.id) + @@allow("all", true) + } + ` + ); + + const db = enhance(); + const post = await db.post.create({ data: {} }); + await expect(db.post.update({ data: { id: post.id + 1 }, where: { id: post.id } })).toBeRejectedByPolicy(); + }); +});