diff --git a/packages/runtime/src/cross/model-meta.ts b/packages/runtime/src/cross/model-meta.ts index 727038fd4..ad8d2aa76 100644 --- a/packages/runtime/src/cross/model-meta.ts +++ b/packages/runtime/src/cross/model-meta.ts @@ -20,6 +20,11 @@ export type RuntimeAttribute = { */ export type FieldDefaultValueProvider = (userContext: unknown) => unknown; +/** + * Action to take when the related model is deleted or updated + */ +export type RelationAction = 'Cascade' | 'Restrict' | 'NoAction' | 'SetNull' | 'SetDefault'; + /** * Runtime information of a data model field */ @@ -74,6 +79,16 @@ export type FieldInfo = { */ isRelationOwner?: boolean; + /** + * Action to take when the related model is deleted. + */ + onDeleteAction?: RelationAction; + + /** + * Action to take when the related model is updated. + */ + onUpdateAction?: RelationAction; + /** * If the field is a foreign key field */ diff --git a/packages/runtime/src/enhancements/node/delegate.ts b/packages/runtime/src/enhancements/node/delegate.ts index 3a4c0a585..b9cc2c033 100644 --- a/packages/runtime/src/enhancements/node/delegate.ts +++ b/packages/runtime/src/enhancements/node/delegate.ts @@ -1160,19 +1160,90 @@ export class DelegateProxyHandler extends DefaultPrismaProxyHandler { } } - private async doDelete(db: CrudContract, model: string, args: any): Promise { + private async doDelete(db: CrudContract, model: string, args: any, readBack = true): Promise { this.injectWhereHierarchy(model, args.where); await this.injectSelectIncludeHierarchy(model, args); + // read relation entities that need to be cascade deleted before deleting the main entity + const cascadeDeletes = await this.getRelationDelegateEntitiesForCascadeDelete(db, model, args.where); + + let result: unknown = undefined; + if (cascadeDeletes.length > 0) { + // we'll need to do cascade deletes of relations, so first + // read the current entity before anything changes + if (readBack) { + result = await this.doFind(db, model, 'findUnique', args); + } + + // process cascade deletes of relations, this ensure their delegate base + // entities are deleted as well + await Promise.all( + cascadeDeletes.map(({ model, entity }) => this.doDelete(db, model, { where: entity }, false)) + ); + } + if (this.options.logPrismaQuery) { this.logger.info(`[delegate] \`delete\` ${this.getModelName(model)}: ${formatObject(args)}`); } - const result = await db[model].delete(args); - const idValues = this.queryUtils.getEntityIds(model, result); + + const deleteResult = await db[model].delete(args); + if (!result) { + result = this.assembleHierarchy(model, deleteResult); + } // recursively delete base entities (they all have the same id values) + const idValues = this.queryUtils.getEntityIds(model, deleteResult); await this.deleteBaseRecursively(db, model, idValues); - return this.assembleHierarchy(model, result); + + return result; + } + + private async getRelationDelegateEntitiesForCascadeDelete(db: CrudContract, model: string, where: any) { + if (!where || Object.keys(where).length === 0) { + throw new Error('where clause is required for cascade delete'); + } + + const cascadeDeletes: Array<{ model: string; entity: any }> = []; + const fields = getFields(this.options.modelMeta, model); + if (fields) { + for (const fieldInfo of Object.values(fields)) { + if (!fieldInfo.isDataModel) { + continue; + } + + if (fieldInfo.isRelationOwner) { + // this side of the relation owns the foreign key, + // so it won't cause cascade delete to the other side + continue; + } + + if (fieldInfo.backLink) { + // get the opposite side of the relation + const backLinkField = this.queryUtils.getModelField(fieldInfo.type, fieldInfo.backLink); + + if (backLinkField?.isRelationOwner && this.isFieldCascadeDelete(backLinkField)) { + // if the opposite side of the relation is to be cascade deleted, + // recursively delete the delegate base entities + const relationModel = getModelInfo(this.options.modelMeta, fieldInfo.type); + if (relationModel?.baseTypes && relationModel.baseTypes.length > 0) { + // the relation model has delegate base, cascade the delete to the base + const relationEntities = await db[relationModel.name].findMany({ + where: { [backLinkField.name]: where }, + select: this.queryUtils.makeIdSelection(relationModel.name), + }); + relationEntities.forEach((entity) => { + cascadeDeletes.push({ model: fieldInfo.type, entity }); + }); + } + } + } + } + } + return cascadeDeletes; + } + + private isFieldCascadeDelete(fieldInfo: FieldInfo) { + return fieldInfo.onDeleteAction === 'Cascade'; } // #endregion diff --git a/packages/sdk/src/model-meta-generator.ts b/packages/sdk/src/model-meta-generator.ts index c5b866417..716e6ad7d 100644 --- a/packages/sdk/src/model-meta-generator.ts +++ b/packages/sdk/src/model-meta-generator.ts @@ -311,6 +311,18 @@ function writeFields( isRelationOwner: true,`); } + const onDeleteAction = getOnDeleteAction(dmField); + if (onDeleteAction) { + writer.write(` + onDeleteAction: '${onDeleteAction}',`); + } + + const onUpdateAction = getOnUpdateAction(dmField); + if (onUpdateAction) { + writer.write(` + onUpdateAction: '${onUpdateAction}',`); + } + if (isForeignKeyField(dmField)) { writer.write(` isForeignKey: true,`); @@ -568,3 +580,25 @@ function writeShortNameMap(options: ModelMetaGeneratorOptions, writer: CodeWrite writer.write(','); } } + +function getOnDeleteAction(fieldInfo: DataModelField) { + const relationAttr = getAttribute(fieldInfo, '@relation'); + if (relationAttr) { + const onDelete = getAttributeArg(relationAttr, 'onDelete'); + if (onDelete && isEnumFieldReference(onDelete)) { + return onDelete.target.ref?.name; + } + } + return undefined; +} + +function getOnUpdateAction(fieldInfo: DataModelField) { + const relationAttr = getAttribute(fieldInfo, '@relation'); + if (relationAttr) { + const onUpdate = getAttributeArg(relationAttr, 'onUpdate'); + if (onUpdate && isEnumFieldReference(onUpdate)) { + return onUpdate.target.ref?.name; + } + } + return undefined; +} diff --git a/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts b/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts index d35fc02c6..544198c32 100644 --- a/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts +++ b/tests/integration/tests/enhancements/with-delegate/enhanced-client.test.ts @@ -1057,7 +1057,7 @@ describe('Polymorphism Test', () => { expect(created.duration).toBe(300); }); - it('delete', async () => { + it('delete simple', async () => { let { db, user, video: ratedVideo } = await setup(); let deleted = await db.ratedVideo.delete({ @@ -1106,6 +1106,55 @@ describe('Polymorphism Test', () => { await expect(db.asset.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); }); + it('delete cascade', async () => { + const { prisma, enhance } = await loadSchema( + ` + model Base { + id Int @id @default(autoincrement()) + type String + @@delegate(type) + } + + model List extends Base { + name String + items Item[] + } + + model Item extends Base { + name String + list List @relation(fields: [listId], references: [id], onDelete: Cascade) + listId Int + content ItemContent? + } + + model ItemContent extends Base { + name String + item Item @relation(fields: [itemId], references: [id], onDelete: Cascade) + itemId Int @unique + } +`, + { enhancements: ['delegate'], logPrismaQuery: true } + ); + + const db = enhance(); + await db.list.create({ + data: { + id: 1, + name: 'list', + items: { + create: [{ id: 2, name: 'item1', content: { create: { id: 3, name: 'content1' } } }], + }, + }, + }); + + const r = await db.list.delete({ where: { id: 1 }, include: { items: { include: { content: true } } } }); + expect(r).toMatchObject({ items: [{ id: 2 }] }); + await expect(db.item.findUnique({ where: { id: 2 } })).toResolveNull(); + await expect(prisma.base.findUnique({ where: { id: 2 } })).toResolveNull(); + await expect(db.itemContent.findUnique({ where: { id: 3 } })).toResolveNull(); + await expect(prisma.base.findUnique({ where: { id: 3 } })).toResolveNull(); + }); + it('deleteMany', async () => { const { enhance } = await loadSchema(schema, { enhancements: ['delegate'] }); const db = enhance();