Skip to content

Commit 2b2bfcf

Browse files
authored
fix: issue 961, incorrect policy injection for nested updateMany (#962)
2 parents 3b9a6c4 + a079add commit 2b2bfcf

File tree

20 files changed

+484
-63
lines changed

20 files changed

+484
-63
lines changed

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "zenstack-monorepo",
3-
"version": "1.7.0",
3+
"version": "1.7.1",
44
"description": "",
55
"scripts": {
66
"build": "pnpm -r build",

packages/ide/jetbrains/build.gradle.kts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ plugins {
55
}
66

77
group = "dev.zenstack"
8-
version = "1.7.0"
8+
version = "1.7.1"
99

1010
repositories {
1111
mavenCentral()

packages/ide/jetbrains/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "jetbrains",
3-
"version": "1.7.0",
3+
"version": "1.7.1",
44
"displayName": "ZenStack JetBrains IDE Plugin",
55
"description": "ZenStack JetBrains IDE plugin",
66
"homepage": "https://zenstack.dev",

packages/language/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@zenstackhq/language",
3-
"version": "1.7.0",
3+
"version": "1.7.1",
44
"displayName": "ZenStack modeling language compiler",
55
"description": "ZenStack modeling language compiler",
66
"homepage": "https://zenstack.dev",

packages/misc/redwood/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "@zenstackhq/redwood",
33
"displayName": "ZenStack RedwoodJS Integration",
4-
"version": "1.7.0",
4+
"version": "1.7.1",
55
"description": "CLI and runtime for integrating ZenStack with RedwoodJS projects.",
66
"repository": {
77
"type": "git",

packages/plugins/openapi/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "@zenstackhq/openapi",
33
"displayName": "ZenStack Plugin and Runtime for OpenAPI",
4-
"version": "1.7.0",
4+
"version": "1.7.1",
55
"description": "ZenStack plugin and runtime supporting OpenAPI",
66
"main": "index.js",
77
"repository": {

packages/plugins/swr/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "@zenstackhq/swr",
33
"displayName": "ZenStack plugin for generating SWR hooks",
4-
"version": "1.7.0",
4+
"version": "1.7.1",
55
"description": "ZenStack plugin for generating SWR hooks",
66
"main": "index.js",
77
"repository": {

packages/plugins/tanstack-query/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "@zenstackhq/tanstack-query",
33
"displayName": "ZenStack plugin for generating tanstack-query hooks",
4-
"version": "1.7.0",
4+
"version": "1.7.1",
55
"description": "ZenStack plugin for generating tanstack-query hooks",
66
"main": "index.js",
77
"exports": {

packages/plugins/trpc/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "@zenstackhq/trpc",
33
"displayName": "ZenStack plugin for tRPC",
4-
"version": "1.7.0",
4+
"version": "1.7.1",
55
"description": "ZenStack plugin for tRPC",
66
"main": "index.js",
77
"repository": {

packages/runtime/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "@zenstackhq/runtime",
33
"displayName": "ZenStack Runtime Library",
4-
"version": "1.7.0",
4+
"version": "1.7.1",
55
"description": "Runtime of ZenStack for both client-side and server-side environments.",
66
"repository": {
77
"type": "git",

packages/runtime/src/cross/model-meta.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,11 @@ export type FieldInfo = {
6666
* Mapping from foreign key field names to relation field names
6767
*/
6868
foreignKeyMapping?: Record<string, string>;
69+
70+
/**
71+
* If the field is an auto-increment field
72+
*/
73+
isAutoIncrement?: boolean;
6974
};
7075

7176
/**

packages/runtime/src/enhancements/policy/handler.ts

Lines changed: 89 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -523,29 +523,16 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
523523
let createResult = await Promise.all(
524524
enumerate(args.data).map(async (item) => {
525525
if (args.skipDuplicates) {
526-
// check unique constraint conflicts
527-
// we can't rely on try/catch/ignore constraint violation error: https://github.com/prisma/prisma/issues/20496
528-
// TODO: for simple cases we should be able to translate it to an `upsert` with empty `update` payload
529-
530-
// for each unique constraint, check if the input item has all fields set, and if so, check if
531-
// an entity already exists, and ignore accordingly
532-
const uniqueConstraints = this.utils.getUniqueConstraints(model);
533-
for (const constraint of Object.values(uniqueConstraints)) {
534-
if (constraint.fields.every((f) => item[f] !== undefined)) {
535-
const uniqueFilter = constraint.fields.reduce((acc, f) => ({ ...acc, [f]: item[f] }), {});
536-
const existing = await this.utils.checkExistence(db, model, uniqueFilter);
537-
if (existing) {
538-
if (this.shouldLogQuery) {
539-
this.logger.info(`[policy] skipping duplicate ${formatObject(item)}`);
540-
}
541-
return undefined;
542-
}
526+
if (await this.hasDuplicatedUniqueConstraint(model, item, db)) {
527+
if (this.shouldLogQuery) {
528+
this.logger.info(`[policy] \`createMany\` skipping duplicate ${formatObject(item)}`);
543529
}
530+
return undefined;
544531
}
545532
}
546533

547534
if (this.shouldLogQuery) {
548-
this.logger.info(`[policy] \`create\` ${model}: ${formatObject(item)}`);
535+
this.logger.info(`[policy] \`create\` for \`createMany\` ${model}: ${formatObject(item)}`);
549536
}
550537
return await db[model].create({ select: this.utils.makeIdSelection(model), data: item });
551538
})
@@ -564,6 +551,26 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
564551
};
565552
}
566553

554+
private async hasDuplicatedUniqueConstraint(model: string, createData: any, db: Record<string, DbOperations>) {
555+
// check unique constraint conflicts
556+
// we can't rely on try/catch/ignore constraint violation error: https://github.com/prisma/prisma/issues/20496
557+
// TODO: for simple cases we should be able to translate it to an `upsert` with empty `update` payload
558+
559+
// for each unique constraint, check if the input item has all fields set, and if so, check if
560+
// an entity already exists, and ignore accordingly
561+
const uniqueConstraints = this.utils.getUniqueConstraints(model);
562+
for (const constraint of Object.values(uniqueConstraints)) {
563+
if (constraint.fields.every((f) => createData[f] !== undefined)) {
564+
const uniqueFilter = constraint.fields.reduce((acc, f) => ({ ...acc, [f]: createData[f] }), {});
565+
const existing = await this.utils.checkExistence(db, model, uniqueFilter);
566+
if (existing) {
567+
return true;
568+
}
569+
}
570+
}
571+
return false;
572+
}
573+
567574
//#endregion
568575

569576
//#region Update & Upsert
@@ -707,17 +714,22 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
707714
postWriteChecks.push(...checks);
708715
};
709716

710-
const _createMany = async (model: string, args: any, context: NestedWriteVisitorContext) => {
711-
if (context.field?.backLink) {
712-
// handles the connection to upstream entity
713-
const reversedQuery = this.utils.buildReversedQuery(context);
714-
for (const item of enumerate(args.data)) {
715-
Object.assign(item, reversedQuery);
717+
const _createMany = async (
718+
model: string,
719+
args: { data: any; skipDuplicates?: boolean },
720+
context: NestedWriteVisitorContext
721+
) => {
722+
for (const item of enumerate(args.data)) {
723+
if (args.skipDuplicates) {
724+
if (await this.hasDuplicatedUniqueConstraint(model, item, db)) {
725+
if (this.shouldLogQuery) {
726+
this.logger.info(`[policy] \`createMany\` skipping duplicate ${formatObject(item)}`);
727+
}
728+
continue;
729+
}
716730
}
731+
await _create(model, item, context);
717732
}
718-
// proceed with the create and collect post-create checks
719-
const { postWriteChecks: checks } = await this.doCreateMany(model, args, db);
720-
postWriteChecks.push(...checks);
721733
};
722734

723735
const _connectDisconnect = async (model: string, args: any, context: NestedWriteVisitorContext) => {
@@ -797,9 +809,6 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
797809
},
798810

799811
updateMany: async (model, args, context) => {
800-
// injects auth guard into where clause
801-
this.utils.injectAuthGuard(db, args, model, 'update');
802-
803812
// prepare for post-update check
804813
if (this.utils.hasAuthGuard(model, 'postUpdate') || this.utils.getZodSchema(model)) {
805814
let select = this.utils.makeIdSelection(model);
@@ -809,10 +818,12 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
809818
}
810819
const reversedQuery = this.utils.buildReversedQuery(context);
811820
const currentSetQuery = { select, where: reversedQuery };
812-
this.utils.injectAuthGuard(db, currentSetQuery, model, 'read');
821+
this.utils.injectAuthGuardAsWhere(db, currentSetQuery, model, 'read');
813822

814823
if (this.shouldLogQuery) {
815-
this.logger.info(`[policy] \`findMany\` ${model}:\n${formatObject(currentSetQuery)}`);
824+
this.logger.info(
825+
`[policy] \`findMany\` for post update check ${model}:\n${formatObject(currentSetQuery)}`
826+
);
816827
}
817828
const currentSet = await db[model].findMany(currentSetQuery);
818829

@@ -825,6 +836,27 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
825836
}))
826837
);
827838
}
839+
840+
const updateGuard = this.utils.getAuthGuard(db, model, 'update');
841+
if (this.utils.isTrue(updateGuard) || this.utils.isFalse(updateGuard)) {
842+
// injects simple auth guard into where clause
843+
this.utils.injectAuthGuardAsWhere(db, args, model, 'update');
844+
} else {
845+
// we have to process `updateMany` separately because the guard may contain
846+
// filters using relation fields which are not allowed in nested `updateMany`
847+
const reversedQuery = this.utils.buildReversedQuery(context);
848+
const updateWhere = this.utils.and(reversedQuery, updateGuard);
849+
if (this.shouldLogQuery) {
850+
this.logger.info(
851+
`[policy] \`updateMany\` ${model}:\n${formatObject({
852+
where: updateWhere,
853+
data: args.data,
854+
})}`
855+
);
856+
}
857+
await db[model].updateMany({ where: updateWhere, data: args.data });
858+
delete context.parent.updateMany;
859+
}
828860
},
829861

830862
create: async (model, args, context) => {
@@ -931,9 +963,21 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
931963
},
932964

933965
deleteMany: async (model, args, context) => {
934-
// inject delete guard
935966
const guard = await this.utils.getAuthGuard(db, model, 'delete');
936-
context.parent.deleteMany = this.utils.and(args, guard);
967+
if (this.utils.isTrue(guard) || this.utils.isFalse(guard)) {
968+
// inject simple auth guard
969+
context.parent.deleteMany = this.utils.and(args, guard);
970+
} else {
971+
// we have to process `deleteMany` separately because the guard may contain
972+
// filters using relation fields which are not allowed in nested `deleteMany`
973+
const reversedQuery = this.utils.buildReversedQuery(context);
974+
const deleteWhere = this.utils.and(reversedQuery, guard);
975+
if (this.shouldLogQuery) {
976+
this.logger.info(`[policy] \`deleteMany\` ${model}:\n${formatObject({ where: deleteWhere })}`);
977+
}
978+
await db[model].deleteMany({ where: deleteWhere });
979+
delete context.parent.deleteMany;
980+
}
937981
},
938982
});
939983

@@ -958,13 +1002,17 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
9581002
}
9591003
for (const k of Object.keys(args)) {
9601004
const field = resolveField(this.modelMeta, model, k);
961-
if (field?.isId || field?.isForeignKey) {
1005+
if (this.isAutoIncrementIdField(field) || field?.isForeignKey) {
9621006
return true;
9631007
}
9641008
}
9651009
return false;
9661010
}
9671011

1012+
private isAutoIncrementIdField(field: FieldInfo) {
1013+
return field.isId && field.isAutoIncrement;
1014+
}
1015+
9681016
async updateMany(args: any) {
9691017
if (!args) {
9701018
throw prismaClientValidationError(this.prisma, 'query argument is required');
@@ -976,7 +1024,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
9761024
this.utils.tryReject(this.prisma, this.model, 'update');
9771025

9781026
args = this.utils.clone(args);
979-
this.utils.injectAuthGuard(this.prisma, args, this.model, 'update');
1027+
this.utils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'update');
9801028

9811029
if (this.utils.hasAuthGuard(this.model, 'postUpdate') || this.utils.getZodSchema(this.model)) {
9821030
// use a transaction to do post-update checks
@@ -989,7 +1037,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
9891037
select = { ...select, ...preValueSelect };
9901038
}
9911039
const currentSetQuery = { select, where: args.where };
992-
this.utils.injectAuthGuard(tx, currentSetQuery, this.model, 'read');
1040+
this.utils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'read');
9931041

9941042
if (this.shouldLogQuery) {
9951043
this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`);
@@ -1118,7 +1166,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
11181166

11191167
// inject policy conditions
11201168
args = args ?? {};
1121-
this.utils.injectAuthGuard(this.prisma, args, this.model, 'delete');
1169+
this.utils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'delete');
11221170

11231171
// conduct the deletion
11241172
if (this.shouldLogQuery) {
@@ -1139,7 +1187,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
11391187
args = this.utils.clone(args);
11401188

11411189
// inject policy conditions
1142-
this.utils.injectAuthGuard(this.prisma, args, this.model, 'read');
1190+
this.utils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read');
11431191

11441192
if (this.shouldLogQuery) {
11451193
this.logger.info(`[policy] \`aggregate\` ${this.model}:\n${formatObject(args)}`);
@@ -1155,7 +1203,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
11551203
args = this.utils.clone(args);
11561204

11571205
// inject policy conditions
1158-
this.utils.injectAuthGuard(this.prisma, args, this.model, 'read');
1206+
this.utils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read');
11591207

11601208
if (this.shouldLogQuery) {
11611209
this.logger.info(`[policy] \`groupBy\` ${this.model}:\n${formatObject(args)}`);
@@ -1166,7 +1214,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
11661214
async count(args: any) {
11671215
// inject policy conditions
11681216
args = args ? this.utils.clone(args) : {};
1169-
this.utils.injectAuthGuard(this.prisma, args, this.model, 'read');
1217+
this.utils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read');
11701218

11711219
if (this.shouldLogQuery) {
11721220
this.logger.info(`[policy] \`count\` ${this.model}:\n${formatObject(args)}`);

0 commit comments

Comments
 (0)