From 03596266cccc6f9deaa6d6e603641c9be2b853ce Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 12 Feb 2024 23:30:14 +0800 Subject: [PATCH] fix: use zod parse result data as mutation input --- .../src/enhancements/policy/handler.ts | 26 ++++++-- .../src/enhancements/policy/policy-utils.ts | 21 +++++++ packages/schema/src/plugins/zod/generator.ts | 6 +- .../with-policy/field-validation.test.ts | 59 +++++++++++++++++++ 4 files changed, 106 insertions(+), 6 deletions(-) diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index e9f4daae0..698dcd364 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -249,7 +249,7 @@ export class PolicyProxyHandler implements Pr // there's no nested write and we've passed input check, proceed with the create directly // validate zod schema if any - this.validateCreateInputSchema(this.model, args.data); + args.data = this.validateCreateInputSchema(this.model, args.data); // make a create args only containing data and ID selection const createArgs: any = { data: args.data, select: this.utils.makeIdSelection(this.model) }; @@ -305,12 +305,20 @@ export class PolicyProxyHandler implements Pr // visit the create payload const visitor = new NestedWriteVisitor(this.modelMeta, { create: async (model, args, context) => { - this.validateCreateInputSchema(model, args); + const validateResult = this.validateCreateInputSchema(model, args); + if (validateResult !== args) { + this.utils.replace(args, validateResult); + } pushIdFields(model, context); }, createMany: async (model, args, context) => { - enumerate(args.data).forEach((item) => this.validateCreateInputSchema(model, item)); + enumerate(args.data).forEach((item) => { + const r = this.validateCreateInputSchema(model, item); + if (r !== item) { + this.utils.replace(item, r); + } + }); pushIdFields(model, context); }, @@ -319,7 +327,9 @@ export class PolicyProxyHandler implements Pr throw this.utils.validationError(`'where' field is required for connectOrCreate`); } - this.validateCreateInputSchema(model, args.create); + if (args.create) { + args.create = this.validateCreateInputSchema(model, args.create); + } const existing = await this.utils.checkExistence(db, model, args.where); if (existing) { @@ -468,6 +478,9 @@ export class PolicyProxyHandler implements Pr parseResult.error ); } + return parseResult.data; + } else { + return data; } } @@ -495,7 +508,10 @@ export class PolicyProxyHandler implements Pr CrudFailureReason.ACCESS_POLICY_VIOLATION ); } else if (inputCheck === true) { - this.validateCreateInputSchema(this.model, item); + const r = this.validateCreateInputSchema(this.model, item); + if (r !== item) { + this.utils.replace(item, r); + } } else if (inputCheck === undefined) { // static policy check is not possible, need to do post-create check needPostCreateCheck = true; diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 388f9cd90..63b83b79f 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -1276,6 +1276,27 @@ export class PolicyUtil { return value ? deepcopy(value) : {}; } + /** + * Replace content of `target` object with `withObject` in-place. + */ + replace(target: any, withObject: any) { + if (!target || typeof target !== 'object' || !withObject || typeof withObject !== 'object') { + return; + } + + // remove missing keys + for (const key of Object.keys(target)) { + if (!(key in withObject)) { + delete target[key]; + } + } + + // overwrite keys + for (const [key, value] of Object.entries(withObject)) { + target[key] = value; + } + } + /** * Picks properties from an object. */ diff --git a/packages/schema/src/plugins/zod/generator.ts b/packages/schema/src/plugins/zod/generator.ts index 2727a781f..d1af70882 100644 --- a/packages/schema/src/plugins/zod/generator.ts +++ b/packages/schema/src/plugins/zod/generator.ts @@ -395,7 +395,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s //////////////////////////////////////////////// // schema for validating prisma create input (all fields optional) - let prismaCreateSchema = makePartial('baseSchema'); + let prismaCreateSchema = makePassthrough(makePartial('baseSchema')); if (refineFuncName) { prismaCreateSchema = `${refineFuncName}(${prismaCreateSchema})`; } @@ -501,3 +501,7 @@ function makeOmit(schema: string, fields: string[]) { function makeMerge(schema1: string, schema2: string): string { return `${schema1}.merge(${schema2})`; } + +function makePassthrough(schema: string) { + return `${schema}.passthrough()`; +} diff --git a/tests/integration/tests/enhancements/with-policy/field-validation.test.ts b/tests/integration/tests/enhancements/with-policy/field-validation.test.ts index 8727f1561..55b9c5cee 100644 --- a/tests/integration/tests/enhancements/with-policy/field-validation.test.ts +++ b/tests/integration/tests/enhancements/with-policy/field-validation.test.ts @@ -35,6 +35,8 @@ describe('With Policy: field validation', () => { text3 String @length(min: 3) text4 String @length(max: 5) text5 String? @endsWith('xyz') + text6 String? @trim @lower + text7 String? @upper @@allow('all', true) } @@ -495,4 +497,61 @@ describe('With Policy: field validation', () => { }) ).toResolveTruthy(); }); + + it('string transformation', async () => { + await db.user.create({ + data: { + id: '1', + password: 'abc123!@#', + email: 'who@myorg.com', + handle: 'user1', + }, + }); + + await expect( + db.userData.create({ + data: { + userId: '1', + a: 1, + b: 0, + c: -1, + d: 0, + text1: 'abc123', + text2: 'def', + text3: 'aaa', + text4: 'abcab', + text6: ' AbC ', + text7: 'abc', + }, + }) + ).resolves.toMatchObject({ text6: 'abc', text7: 'ABC' }); + + await expect( + db.user.create({ + data: { + id: '2', + password: 'abc123!@#', + email: 'who@myorg.com', + handle: 'user2', + userData: { + create: { + a: 1, + b: 0, + c: -1, + d: 0, + text1: 'abc123', + text2: 'def', + text3: 'aaa', + text4: 'abcab', + text6: ' AbC ', + text7: 'abc', + }, + }, + }, + include: { userData: true }, + }) + ).resolves.toMatchObject({ + userData: expect.objectContaining({ text6: 'abc', text7: 'ABC' }), + }); + }); });