diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index dc3e9cf30..cd098706c 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -622,8 +622,18 @@ export class PolicyProxyHandler implements Pr const _create = async (model: string, args: any, context: NestedWriteVisitorContext) => { let createData = args; if (context.field?.backLink) { + // Check if the create payload contains any "unsafe" assignment: + // assign id or foreign key fields. + // + // The reason why we need to do that is Prisma's mutations payload + // structure has two mutually exclusive forms for safe and unsafe + // operations. E.g.: + // - safe: { data: { user: { connect: { id: 1 }} } } + // - unsafe: { data: { userId: 1 } } + const unsafe = this.isUnsafeMutate(model, args); + // handles the connection to upstream entity - const reversedQuery = this.utils.buildReversedQuery(context); + const reversedQuery = this.utils.buildReversedQuery(context, true, unsafe); if (reversedQuery[context.field.backLink]) { // the built reverse query contains a condition for the backlink field, build a "connect" with it createData = { @@ -881,6 +891,19 @@ export class PolicyProxyHandler implements Pr return { result, postWriteChecks }; } + private isUnsafeMutate(model: string, args: any) { + if (!args) { + return false; + } + for (const k of Object.keys(args)) { + const field = resolveField(this.modelMeta, model, k); + if (field?.isId || field?.isForeignKey) { + return true; + } + } + return false; + } + async updateMany(args: any) { if (!args) { throw prismaClientValidationError(this.prisma, 'query argument is required'); diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index a27d689b4..12c8c57be 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -484,7 +484,7 @@ export class PolicyUtil { /** * Builds a reversed query for the given nested path. */ - buildReversedQuery(context: NestedWriteVisitorContext) { + buildReversedQuery(context: NestedWriteVisitorContext, mutating = false, unsafeOperation = false) { let result, currQuery: any; let currField: FieldInfo | undefined; @@ -509,19 +509,41 @@ export class PolicyUtil { if (!currField.backLink) { throw this.unknownError(`field ${currField.type}.${currField.name} doesn't have a backLink`); } + const backLinkField = this.getModelField(currField.type, currField.backLink); - if (backLinkField?.isArray) { + if (!backLinkField) { + throw this.unknownError(`missing backLink field ${currField.backLink} in ${currField.type}`); + } + + if (backLinkField.isArray) { // many-side of relationship, wrap with "some" query currQuery[currField.backLink] = { some: { ...visitWhere } }; } else { - if (where && backLinkField.isRelationOwner && backLinkField.foreignKeyMapping) { - for (const [r, fk] of Object.entries(backLinkField.foreignKeyMapping)) { + const fkMapping = where && backLinkField.isRelationOwner && backLinkField.foreignKeyMapping; + + // calculate if we should preserve the relation condition (e.g., { user: { id: 1 } }) + const shouldPreserveRelationCondition = + // doing a mutation + mutating && + // and it's a safe mutate + !unsafeOperation && + // and the current segment is the direct parent (the last one is the mutate itself), + // the relation condition should be preserved and will be converted to a "connect" later + i === context.nestingPath.length - 2; + + if (fkMapping && !shouldPreserveRelationCondition) { + // turn relation condition into foreign key condition, e.g.: + // { user: { id: 1 } } => { userId: 1 } + for (const [r, fk] of Object.entries(fkMapping)) { currQuery[fk] = visitWhere[r]; } + if (i > 0) { + // prepare for the next segment currQuery[currField.backLink] = {}; } } else { + // preserve the original structure currQuery[currField.backLink] = { ...visitWhere }; } } diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index b09b5052c..445703991 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -110,6 +110,11 @@ export type FieldInfo = { */ isRelationOwner: boolean; + /** + * If the field is a foreign key field + */ + isForeignKey: boolean; + /** * Mapping from foreign key field names to relation field names */ diff --git a/packages/schema/src/plugins/model-meta/index.ts b/packages/schema/src/plugins/model-meta/index.ts index 8c4432db5..2c0148751 100644 --- a/packages/schema/src/plugins/model-meta/index.ts +++ b/packages/schema/src/plugins/model-meta/index.ts @@ -19,6 +19,7 @@ import { getDataModels, getLiteral, hasAttribute, + isForeignKeyField, isIdField, PluginError, PluginFunction, @@ -95,6 +96,7 @@ function generateModelMetadata(dataModels: DataModel[], writer: CodeBlockWriter) attributes: ${JSON.stringify(getFieldAttributes(f))}, backLink: ${backlink ? "'" + backlink.name + "'" : 'undefined'}, isRelationOwner: ${isRelationOwner(f, backlink)}, + isForeignKey: ${isForeignKeyField(f)}, foreignKeyMapping: ${fkMapping ? JSON.stringify(fkMapping) : 'undefined'} },`); } diff --git a/tests/integration/tests/regression/issue-714.test.ts b/tests/integration/tests/regression/issue-714.test.ts new file mode 100644 index 000000000..673d3a689 --- /dev/null +++ b/tests/integration/tests/regression/issue-714.test.ts @@ -0,0 +1,168 @@ +import { createPostgresDb, dropPostgresDb, loadSchema } from '@zenstackhq/testtools'; + +const DB_NAME = 'issue-714'; + +describe('Regression: issue 714', () => { + let dbUrl: string; + let prisma: any; + + beforeEach(async () => { + dbUrl = await createPostgresDb(DB_NAME); + }); + + afterEach(async () => { + if (prisma) { + await prisma.$disconnect(); + } + await dropPostgresDb(DB_NAME); + }); + + it('regression', async () => { + const { prisma: _prisma, enhance } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + username String @unique + + employedBy CompanyUser[] + properties PropertyUser[] + companies Company[] + + @@allow('all', true) + } + + model Company { + id Int @id @default(autoincrement()) + name String + + companyUsers CompanyUser[] + propertyUsers User[] + properties Property[] + + @@allow('all', true) + } + + model CompanyUser { + company Company @relation(fields: [companyId], references: [id]) + companyId Int + user User @relation(fields: [userId], references: [id]) + userId Int + + dummyField String + + @@id([companyId, userId]) + + @@allow('all', true) + } + + enum PropertyUserRoleType { + Owner + Administrator + } + + model PropertyUserRole { + id Int @id @default(autoincrement()) + type PropertyUserRoleType + + user PropertyUser @relation(fields: [userId], references: [id]) + userId Int + + @@allow('all', true) + } + + model PropertyUser { + id Int @id @default(autoincrement()) + dummyField String + + property Property @relation(fields: [propertyId], references: [id]) + propertyId Int + user User @relation(fields: [userId], references: [id]) + userId Int + + roles PropertyUserRole[] + + @@unique([propertyId, userId]) + + @@allow('all', true) + } + + model Property { + id Int @id @default(autoincrement()) + name String + + users PropertyUser[] + company Company @relation(fields: [companyId], references: [id]) + companyId Int + + @@allow('all', true) + } + `, + { + provider: 'postgresql', + dbUrl, + } + ); + + prisma = _prisma; + const db = enhance(); + + await db.user.create({ + data: { + username: 'test@example.com', + }, + }); + + await db.company.create({ + data: { + name: 'My Company', + companyUsers: { + create: { + dummyField: '', + user: { + connect: { + id: 1, + }, + }, + }, + }, + propertyUsers: { + connect: { + id: 1, + }, + }, + properties: { + create: [ + { + name: 'Test', + }, + ], + }, + }, + }); + + await db.property.update({ + data: { + users: { + create: { + dummyField: '', + roles: { + createMany: { + data: { + type: 'Owner', + }, + }, + }, + user: { + connect: { + id: 1, + }, + }, + }, + }, + }, + where: { + id: 1, + }, + }); + }); +});