From 69fbc12f58c186bbe2240a7f6cf61a15963e35b9 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 28 Nov 2023 20:05:21 -0800 Subject: [PATCH] fix: query injection error when create (in array form) is nested inside an update --- packages/runtime/package.json | 1 + packages/runtime/src/cross/model-meta.ts | 13 +- .../runtime/src/cross/nested-write-visitor.ts | 79 +++----- .../src/enhancements/policy/handler.ts | 67 +++++-- pnpm-lock.yaml | 3 + .../enhancements/with-policy/refactor.test.ts | 13 +- .../tests/regression/issue-864.test.ts | 185 ++++++++++++++++++ 7 files changed, 293 insertions(+), 68 deletions(-) create mode 100644 tests/integration/tests/regression/issue-864.test.ts diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 2656ec225..c8b6a196f 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -64,6 +64,7 @@ "pluralize": "^8.0.0", "semver": "^7.3.8", "superjson": "^1.11.0", + "tiny-invariant": "^1.3.1", "tslib": "^2.4.1", "upper-case-first": "^2.0.2", "uuid": "^9.0.0", diff --git a/packages/runtime/src/cross/model-meta.ts b/packages/runtime/src/cross/model-meta.ts index 37636de5a..817819b8c 100644 --- a/packages/runtime/src/cross/model-meta.ts +++ b/packages/runtime/src/cross/model-meta.ts @@ -102,10 +102,21 @@ export type ModelMeta = { /** * Resolves a model field to its metadata. Returns undefined if not found. */ -export function resolveField(modelMeta: ModelMeta, model: string, field: string): FieldInfo | undefined { +export function resolveField(modelMeta: ModelMeta, model: string, field: string) { return modelMeta.fields[lowerCaseFirst(model)]?.[field]; } +/** + * Resolves a model field to its metadata. Throws an error if not found. + */ +export function requireField(modelMeta: ModelMeta, model: string, field: string) { + const f = resolveField(modelMeta, model, field); + if (!f) { + throw new Error(`Field ${model}.${field} cannot be resolved`); + } + return f; +} + /** * Gets all fields of a model. */ diff --git a/packages/runtime/src/cross/nested-write-visitor.ts b/packages/runtime/src/cross/nested-write-visitor.ts index 9b2a9a628..7d67f6d9b 100644 --- a/packages/runtime/src/cross/nested-write-visitor.ts +++ b/packages/runtime/src/cross/nested-write-visitor.ts @@ -145,49 +145,53 @@ export class NestedWriteVisitor { return; } - const context = { parent, field, nestingPath: [...nestingPath] }; const toplevel = field == undefined; + const context = { parent, field, nestingPath: [...nestingPath] }; + const pushNewContext = (field: FieldInfo | undefined, model: string, where: any, unique = false) => { + return { ...context, nestingPath: [...context.nestingPath, { field, model, where, unique }] }; + }; + // visit payload switch (action) { case 'create': - context.nestingPath.push({ field, model, where: {}, unique: false }); for (const item of enumerate(data)) { + const newContext = pushNewContext(field, model, {}); let callbackResult: any; if (this.callback.create) { - callbackResult = await this.callback.create(model, item, context); + callbackResult = await this.callback.create(model, item, newContext); } if (callbackResult !== false) { const subPayload = typeof callbackResult === 'object' ? callbackResult : item; - await this.visitSubPayload(model, action, subPayload, context.nestingPath); + await this.visitSubPayload(model, action, subPayload, newContext.nestingPath); } } break; case 'createMany': if (data) { - context.nestingPath.push({ field, model, where: {}, unique: false }); + const newContext = pushNewContext(field, model, {}); let callbackResult: any; if (this.callback.createMany) { - callbackResult = await this.callback.createMany(model, data, context); + callbackResult = await this.callback.createMany(model, data, newContext); } if (callbackResult !== false) { const subPayload = typeof callbackResult === 'object' ? callbackResult : data.data; - await this.visitSubPayload(model, action, subPayload, context.nestingPath); + await this.visitSubPayload(model, action, subPayload, newContext.nestingPath); } } break; case 'connectOrCreate': - context.nestingPath.push({ field, model, where: data.where, unique: false }); for (const item of enumerate(data)) { + const newContext = pushNewContext(field, model, item.where); let callbackResult: any; if (this.callback.connectOrCreate) { - callbackResult = await this.callback.connectOrCreate(model, item, context); + callbackResult = await this.callback.connectOrCreate(model, item, newContext); } if (callbackResult !== false) { const subPayload = typeof callbackResult === 'object' ? callbackResult : item.create; - await this.visitSubPayload(model, action, subPayload, context.nestingPath); + await this.visitSubPayload(model, action, subPayload, newContext.nestingPath); } } break; @@ -195,10 +199,7 @@ export class NestedWriteVisitor { case 'connect': if (this.callback.connect) { for (const item of enumerate(data)) { - const newContext = { - ...context, - nestingPath: [...context.nestingPath, { field, model, where: item, unique: true }], - }; + const newContext = pushNewContext(field, model, item, true); await this.callback.connect(model, item, newContext); } } @@ -210,13 +211,7 @@ export class NestedWriteVisitor { // if relation is to-one, the payload can only be boolean `true` if (this.callback.disconnect) { for (const item of enumerate(data)) { - const newContext = { - ...context, - nestingPath: [ - ...context.nestingPath, - { field, model, where: item, unique: typeof item === 'object' }, - ], - }; + const newContext = pushNewContext(field, model, item, typeof item === 'object'); await this.callback.disconnect(model, item, newContext); } } @@ -224,17 +219,17 @@ export class NestedWriteVisitor { case 'set': if (this.callback.set) { - context.nestingPath.push({ field, model, where: {}, unique: false }); - await this.callback.set(model, data, context); + const newContext = pushNewContext(field, model, {}); + await this.callback.set(model, data, newContext); } break; case 'update': - context.nestingPath.push({ field, model, where: data.where, unique: false }); for (const item of enumerate(data)) { + const newContext = pushNewContext(field, model, item.where); let callbackResult: any; if (this.callback.update) { - callbackResult = await this.callback.update(model, item, context); + callbackResult = await this.callback.update(model, item, newContext); } if (callbackResult !== false) { const subPayload = @@ -243,38 +238,38 @@ export class NestedWriteVisitor { : typeof item.data === 'object' ? item.data : item; - await this.visitSubPayload(model, action, subPayload, context.nestingPath); + await this.visitSubPayload(model, action, subPayload, newContext.nestingPath); } } break; case 'updateMany': - context.nestingPath.push({ field, model, where: data.where, unique: false }); for (const item of enumerate(data)) { + const newContext = pushNewContext(field, model, item.where); let callbackResult: any; if (this.callback.updateMany) { - callbackResult = await this.callback.updateMany(model, item, context); + callbackResult = await this.callback.updateMany(model, item, newContext); } if (callbackResult !== false) { const subPayload = typeof callbackResult === 'object' ? callbackResult : item; - await this.visitSubPayload(model, action, subPayload, context.nestingPath); + await this.visitSubPayload(model, action, subPayload, newContext.nestingPath); } } break; case 'upsert': { - context.nestingPath.push({ field, model, where: data.where, unique: false }); for (const item of enumerate(data)) { + const newContext = pushNewContext(field, model, item.where); let callbackResult: any; if (this.callback.upsert) { - callbackResult = await this.callback.upsert(model, item, context); + callbackResult = await this.callback.upsert(model, item, newContext); } if (callbackResult !== false) { if (typeof callbackResult === 'object') { - await this.visitSubPayload(model, action, callbackResult, context.nestingPath); + await this.visitSubPayload(model, action, callbackResult, newContext.nestingPath); } else { - await this.visitSubPayload(model, action, item.create, context.nestingPath); - await this.visitSubPayload(model, action, item.update, context.nestingPath); + await this.visitSubPayload(model, action, item.create, newContext.nestingPath); + await this.visitSubPayload(model, action, item.update, newContext.nestingPath); } } } @@ -284,13 +279,7 @@ export class NestedWriteVisitor { case 'delete': { if (this.callback.delete) { for (const item of enumerate(data)) { - const newContext = { - ...context, - nestingPath: [ - ...context.nestingPath, - { field, model, where: toplevel ? item.where : item, unique: false }, - ], - }; + const newContext = pushNewContext(field, model, toplevel ? item.where : item); await this.callback.delete(model, item, newContext); } } @@ -300,13 +289,7 @@ export class NestedWriteVisitor { case 'deleteMany': if (this.callback.deleteMany) { for (const item of enumerate(data)) { - const newContext = { - ...context, - nestingPath: [ - ...context.nestingPath, - { field, model, where: toplevel ? item.where : item, unique: false }, - ], - }; + const newContext = pushNewContext(field, model, toplevel ? item.where : item); await this.callback.deleteMany(model, item, newContext); } } diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 43d5ba665..f002002d2 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -1,6 +1,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { lowerCaseFirst } from 'lower-case-first'; +import invariant from 'tiny-invariant'; import { upperCaseFirst } from 'upper-case-first'; import { fromZodError } from 'zod-validation-error'; import { CrudFailureReason, PRISMA_TX_FLAG } from '../../constants'; @@ -10,6 +11,7 @@ import { NestedWriteVisitorContext, enumerate, getIdFields, + requireField, resolveField, type FieldInfo, type ModelMeta, @@ -641,8 +643,9 @@ export class PolicyProxyHandler implements Pr // handles the connection to upstream entity const reversedQuery = this.utils.buildReversedQuery(context, true, unsafe); - if (reversedQuery[context.field.backLink]) { - // the built reverse query contains a condition for the backlink field, build a "connect" with it + if ((!unsafe || context.field.isRelationOwner) && reversedQuery[context.field.backLink]) { + // if mutation is safe, or current field owns the relation (so the other side has no fk), + // and the reverse query contains the back link, then we can build a "connect" with it createData = { ...createData, [context.field.backLink]: { @@ -650,11 +653,52 @@ export class PolicyProxyHandler implements Pr }, }; } else { - // otherwise, the reverse query is translated to foreign key setting, merge it to the create data - createData = { - ...createData, - ...reversedQuery, - }; + // otherwise, the reverse query should be translated to foreign key setting + // and merged to the create data + + const backLinkField = this.requireBackLink(context.field); + invariant(backLinkField.foreignKeyMapping); + + // try to extract foreign key values from the reverse query + let fkValues = Object.values(backLinkField.foreignKeyMapping).reduce((obj, fk) => { + obj[fk] = reversedQuery[fk]; + return obj; + }, {}); + + if (Object.values(fkValues).every((v) => v !== undefined)) { + // all foreign key values are available, merge them to the create data + createData = { + ...createData, + ...fkValues, + }; + } else { + // some foreign key values are missing, need to look up the upstream entity, + // this can happen when the upstream entity doesn't have a unique where clause, + // for example when it's nested inside a one-to-one update + const upstreamQuery = { + where: reversedQuery[backLinkField.name], + select: this.utils.makeIdSelection(backLinkField.type), + }; + + // fetch the upstream entity + if (this.logger.enabled('info')) { + this.logger.info( + `[policy] \`findUniqueOrThrow\` ${model}: looking up upstream entity of ${ + backLinkField.type + }, ${formatObject(upstreamQuery)}` + ); + } + const upstreamEntity = await this.prisma[backLinkField.type].findUniqueOrThrow(upstreamQuery); + + // map ids to foreign keys + fkValues = Object.entries(backLinkField.foreignKeyMapping).reduce((obj, [id, fk]) => { + obj[fk] = upstreamEntity[id]; + return obj; + }, {}); + + // merge them to the create data + createData = { ...createData, ...fkValues }; + } } } @@ -1192,7 +1236,7 @@ export class PolicyProxyHandler implements Pr // already in transaction, don't nest return action(this.prisma); } else { - return this.prisma.$transaction((tx) => action(tx)); + return this.prisma.$transaction((tx) => action(tx), { maxWait: 100000, timeout: 100000 }); } } @@ -1217,11 +1261,8 @@ export class PolicyProxyHandler implements Pr } private requireBackLink(fieldInfo: FieldInfo) { - const backLinkField = fieldInfo.backLink && resolveField(this.modelMeta, fieldInfo.type, fieldInfo.backLink); - if (!backLinkField) { - throw new Error('Missing back link for field: ' + fieldInfo.name); - } - return backLinkField; + invariant(fieldInfo.backLink, `back link not found for field ${fieldInfo.name}`); + return requireField(this.modelMeta, fieldInfo.type, fieldInfo.backLink); } //#endregion diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 9da09d8b5..51c19c0fb 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -441,6 +441,9 @@ importers: superjson: specifier: ^1.11.0 version: 1.11.0 + tiny-invariant: + specifier: ^1.3.1 + version: 1.3.1 tslib: specifier: ^2.4.1 version: 2.4.1 diff --git a/tests/integration/tests/enhancements/with-policy/refactor.test.ts b/tests/integration/tests/enhancements/with-policy/refactor.test.ts index f65b05f69..126c038fa 100644 --- a/tests/integration/tests/enhancements/with-policy/refactor.test.ts +++ b/tests/integration/tests/enhancements/with-policy/refactor.test.ts @@ -26,6 +26,7 @@ describe('With Policy: refactor tests', () => { { provider: 'postgresql', dbUrl, + logPrismaQuery: true, } ); getDb = withPolicy; @@ -455,7 +456,7 @@ describe('With Policy: refactor tests', () => { await expect(user1Db.post.findFirst({ where: { id: { in: [4, 5] } } })).toResolveNull(); }); - it('update', async () => { + it('update single', async () => { await prisma.user.create({ data: { id: 2, @@ -643,7 +644,7 @@ describe('With Policy: refactor tests', () => { }, }) ).toResolveTruthy(); - expect( + await expect( user1Db.user.update({ include: { posts: true }, where: { id: 1 }, @@ -799,7 +800,7 @@ describe('With Policy: refactor tests', () => { upsert: { where: { id: 1 }, update: { title: 'Post 1-1' }, // update - create: { id: 1, title: 'Post 1' }, + create: { id: 7, title: 'Post 1' }, }, }, }, @@ -814,7 +815,7 @@ describe('With Policy: refactor tests', () => { upsert: { where: { id: 7 }, update: { title: 'Post 7-1' }, - create: { id: 1, title: 'Post 7' }, // create + create: { id: 7, title: 'Post 7' }, // create }, }, }, @@ -843,7 +844,7 @@ describe('With Policy: refactor tests', () => { upsert: { where: { id: 7 }, update: { title: 'Post 7 very long' }, - create: { id: 1, title: 'Post 7' }, + create: { title: 'Post 7' }, }, }, }, @@ -1098,7 +1099,7 @@ describe('With Policy: refactor tests', () => { ).resolves.toMatchObject({ count: 2 }); }); - it('delete', async () => { + it('delete single', async () => { await prisma.user.create({ data: { id: 1, diff --git a/tests/integration/tests/regression/issue-864.test.ts b/tests/integration/tests/regression/issue-864.test.ts new file mode 100644 index 000000000..02aab81d1 --- /dev/null +++ b/tests/integration/tests/regression/issue-864.test.ts @@ -0,0 +1,185 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Regression: issue nested create', () => { + it('safe create', async () => { + const { prisma, enhance } = await loadSchema( + ` + model A { + id Int @id @default(autoincrement()) + aValue Int + b B[] + + @@allow('all', aValue > 0) + } + + model B { + id Int @id @default(autoincrement()) + bValue Int + aId Int + a A @relation(fields: [aId], references: [id]) + c C[] + + @@allow('all', bValue > 0) + } + + model C { + id Int @id @default(autoincrement()) + cValue Int + bId Int + b B @relation(fields: [bId], references: [id]) + + @@allow('all', cValue > 0) + } + ` + ); + + await prisma.a.create({ + data: { id: 1, aValue: 1, b: { create: { id: 2, bValue: 2 } } }, + include: { b: true }, + }); + + const db = enhance(); + await db.a.update({ + where: { id: 1 }, + data: { + b: { + update: [ + { + where: { id: 2 }, + data: { + c: { + create: [ + { + cValue: 3, + }, + ], + }, + }, + }, + ], + }, + }, + }); + }); + + it('unsafe create nested in to-many', async () => { + const { prisma, enhance } = await loadSchema( + ` + model A { + id Int @id @default(autoincrement()) + aValue Int + b B[] + + @@allow('all', aValue > 0) + } + + model B { + id Int @id @default(autoincrement()) + bValue Int + aId Int + a A @relation(fields: [aId], references: [id]) + c C[] + + @@allow('all', bValue > 0) + } + + model C { + id Int @id @default(autoincrement()) + cValue Int + bId Int + b B @relation(fields: [bId], references: [id]) + + @@allow('all', cValue > 0) + } + ` + ); + + await prisma.a.create({ + data: { id: 1, aValue: 1, b: { create: { id: 2, bValue: 2 } } }, + include: { b: true }, + }); + + const db = enhance(); + await db.a.update({ + where: { id: 1 }, + data: { + b: { + update: [ + { + where: { id: 2 }, + data: { + c: { + create: [ + { + id: 1, + cValue: 3, + }, + ], + }, + }, + }, + ], + }, + }, + }); + }); + + it('unsafe create nested in to-one', async () => { + const { prisma, enhance } = await loadSchema( + ` + model A { + id Int @id @default(autoincrement()) + aValue Int + b B? + + @@allow('all', aValue > 0) + } + + model B { + id Int @id @default(autoincrement()) + bValue Int + aId Int @unique + a A @relation(fields: [aId], references: [id]) + c C[] + + @@allow('all', bValue > 0) + } + + model C { + id Int @id @default(autoincrement()) + cValue Int + bId Int + b B @relation(fields: [bId], references: [id]) + + @@allow('all', cValue > 0) + } + ` + ); + + await prisma.a.create({ + data: { id: 1, aValue: 1, b: { create: { id: 2, bValue: 2 } } }, + include: { b: true }, + }); + + const db = enhance(); + await db.a.update({ + where: { id: 1 }, + data: { + b: { + update: { + data: { + c: { + create: [ + { + id: 1, + cValue: 3, + }, + ], + }, + }, + }, + }, + }, + }); + }); +});