Skip to content

Commit 018d59f

Browse files
authored
feat: allow to pass in a custom Prisma module when calling enhance (#1160)
1 parent 269809a commit 018d59f

File tree

5 files changed

+90
-35
lines changed

5 files changed

+90
-35
lines changed

packages/runtime/src/enhancements/policy/handler.ts

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import { lowerCaseFirst } from 'lower-case-first';
44
import invariant from 'tiny-invariant';
55
import { upperCaseFirst } from 'upper-case-first';
66
import { fromZodError } from 'zod-validation-error';
7+
import type { WithPolicyOptions } from '.';
78
import { CrudFailureReason } from '../../constants';
89
import {
910
ModelDataVisitor,
@@ -23,7 +24,6 @@ import { formatObject, prismaClientValidationError } from '../utils';
2324
import { Logger } from './logger';
2425
import { PolicyUtil } from './policy-utils';
2526
import { createDeferredPromise } from './promise';
26-
import { WithPolicyOptions } from '.';
2727

2828
// a record for post-write policy check
2929
type PostWriteCheckRecord = {
@@ -58,6 +58,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
5858
this.logger = new Logger(prisma);
5959
this.utils = new PolicyUtil(
6060
this.prisma,
61+
this.options,
6162
this.modelMeta,
6263
this.policy,
6364
this.zodSchemas,
@@ -77,20 +78,20 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
7778

7879
findUnique(args: any) {
7980
if (!args) {
80-
throw prismaClientValidationError(this.prisma, 'query argument is required');
81+
throw prismaClientValidationError(this.prisma, this.options, 'query argument is required');
8182
}
8283
if (!args.where) {
83-
throw prismaClientValidationError(this.prisma, 'where field is required in query argument');
84+
throw prismaClientValidationError(this.prisma, this.options, 'where field is required in query argument');
8485
}
8586
return this.findWithFluentCallStubs(args, 'findUnique', false, () => null);
8687
}
8788

8889
findUniqueOrThrow(args: any) {
8990
if (!args) {
90-
throw prismaClientValidationError(this.prisma, 'query argument is required');
91+
throw prismaClientValidationError(this.prisma, this.options, 'query argument is required');
9192
}
9293
if (!args.where) {
93-
throw prismaClientValidationError(this.prisma, 'where field is required in query argument');
94+
throw prismaClientValidationError(this.prisma, this.options, 'where field is required in query argument');
9495
}
9596
return this.findWithFluentCallStubs(args, 'findUniqueOrThrow', true, () => {
9697
throw this.utils.notFound(this.model);
@@ -220,10 +221,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
220221

221222
async create(args: any) {
222223
if (!args) {
223-
throw prismaClientValidationError(this.prisma, 'query argument is required');
224+
throw prismaClientValidationError(this.prisma, this.options, 'query argument is required');
224225
}
225226
if (!args.data) {
226-
throw prismaClientValidationError(this.prisma, 'data field is required in query argument');
227+
throw prismaClientValidationError(this.prisma, this.options, 'data field is required in query argument');
227228
}
228229

229230
this.utils.tryReject(this.prisma, this.model, 'create');
@@ -476,10 +477,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
476477

477478
async createMany(args: { data: any; skipDuplicates?: boolean }) {
478479
if (!args) {
479-
throw prismaClientValidationError(this.prisma, 'query argument is required');
480+
throw prismaClientValidationError(this.prisma, this.options, 'query argument is required');
480481
}
481482
if (!args.data) {
482-
throw prismaClientValidationError(this.prisma, 'data field is required in query argument');
483+
throw prismaClientValidationError(this.prisma, this.options, 'data field is required in query argument');
483484
}
484485

485486
this.utils.tryReject(this.prisma, this.model, 'create');
@@ -596,13 +597,13 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
596597

597598
async update(args: any) {
598599
if (!args) {
599-
throw prismaClientValidationError(this.prisma, 'query argument is required');
600+
throw prismaClientValidationError(this.prisma, this.options, 'query argument is required');
600601
}
601602
if (!args.where) {
602-
throw prismaClientValidationError(this.prisma, 'where field is required in query argument');
603+
throw prismaClientValidationError(this.prisma, this.options, 'where field is required in query argument');
603604
}
604605
if (!args.data) {
605-
throw prismaClientValidationError(this.prisma, 'data field is required in query argument');
606+
throw prismaClientValidationError(this.prisma, this.options, 'data field is required in query argument');
606607
}
607608

608609
args = this.utils.clone(args);
@@ -1071,10 +1072,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
10711072

10721073
async updateMany(args: any) {
10731074
if (!args) {
1074-
throw prismaClientValidationError(this.prisma, 'query argument is required');
1075+
throw prismaClientValidationError(this.prisma, this.options, 'query argument is required');
10751076
}
10761077
if (!args.data) {
1077-
throw prismaClientValidationError(this.prisma, 'data field is required in query argument');
1078+
throw prismaClientValidationError(this.prisma, this.options, 'data field is required in query argument');
10781079
}
10791080

10801081
this.utils.tryReject(this.prisma, this.model, 'update');
@@ -1130,16 +1131,16 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
11301131

11311132
async upsert(args: any) {
11321133
if (!args) {
1133-
throw prismaClientValidationError(this.prisma, 'query argument is required');
1134+
throw prismaClientValidationError(this.prisma, this.options, 'query argument is required');
11341135
}
11351136
if (!args.where) {
1136-
throw prismaClientValidationError(this.prisma, 'where field is required in query argument');
1137+
throw prismaClientValidationError(this.prisma, this.options, 'where field is required in query argument');
11371138
}
11381139
if (!args.create) {
1139-
throw prismaClientValidationError(this.prisma, 'create field is required in query argument');
1140+
throw prismaClientValidationError(this.prisma, this.options, 'create field is required in query argument');
11401141
}
11411142
if (!args.update) {
1142-
throw prismaClientValidationError(this.prisma, 'update field is required in query argument');
1143+
throw prismaClientValidationError(this.prisma, this.options, 'update field is required in query argument');
11431144
}
11441145

11451146
this.utils.tryReject(this.prisma, this.model, 'create');
@@ -1183,10 +1184,10 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
11831184

11841185
async delete(args: any) {
11851186
if (!args) {
1186-
throw prismaClientValidationError(this.prisma, 'query argument is required');
1187+
throw prismaClientValidationError(this.prisma, this.options, 'query argument is required');
11871188
}
11881189
if (!args.where) {
1189-
throw prismaClientValidationError(this.prisma, 'where field is required in query argument');
1190+
throw prismaClientValidationError(this.prisma, this.options, 'where field is required in query argument');
11901191
}
11911192

11921193
this.utils.tryReject(this.prisma, this.model, 'delete');
@@ -1239,7 +1240,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
12391240

12401241
async aggregate(args: any) {
12411242
if (!args) {
1242-
throw prismaClientValidationError(this.prisma, 'query argument is required');
1243+
throw prismaClientValidationError(this.prisma, this.options, 'query argument is required');
12431244
}
12441245

12451246
args = this.utils.clone(args);
@@ -1255,7 +1256,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
12551256

12561257
async groupBy(args: any) {
12571258
if (!args) {
1258-
throw prismaClientValidationError(this.prisma, 'query argument is required');
1259+
throw prismaClientValidationError(this.prisma, this.options, 'query argument is required');
12591260
}
12601261

12611262
args = this.utils.clone(args);
@@ -1299,7 +1300,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
12991300
args = { create: {}, update: {}, delete: {} };
13001301
} else {
13011302
if (typeof args !== 'object') {
1302-
throw prismaClientValidationError(this.prisma, 'argument must be an object');
1303+
throw prismaClientValidationError(this.prisma, this.options, 'argument must be an object');
13031304
}
13041305
if (Object.keys(args).length === 0) {
13051306
// include all

packages/runtime/src/enhancements/policy/policy-utils.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { lowerCaseFirst } from 'lower-case-first';
55
import { upperCaseFirst } from 'upper-case-first';
66
import { ZodError } from 'zod';
77
import { fromZodError } from 'zod-validation-error';
8+
import type { EnhancementOptions } from '..';
89
import {
910
CrudFailureReason,
1011
FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX,
@@ -48,6 +49,7 @@ export class PolicyUtil {
4849

4950
constructor(
5051
private readonly db: DbClientContract,
52+
private readonly options: EnhancementOptions | undefined,
5153
private readonly modelMeta: ModelMeta,
5254
private readonly policy: PolicyDef,
5355
private readonly zodSchemas: ZodSchemas | undefined,
@@ -1098,24 +1100,25 @@ export class PolicyUtil {
10981100

10991101
return prismaClientKnownRequestError(
11001102
this.db,
1103+
this.options,
11011104
`denied by policy: ${model} entities failed '${operation}' check${extra ? ', ' + extra : ''}`,
11021105
args
11031106
);
11041107
}
11051108

11061109
notFound(model: string) {
1107-
return prismaClientKnownRequestError(this.db, `entity not found for model ${model}`, {
1110+
return prismaClientKnownRequestError(this.db, this.options, `entity not found for model ${model}`, {
11081111
clientVersion: getVersion(),
11091112
code: 'P2025',
11101113
});
11111114
}
11121115

11131116
validationError(message: string) {
1114-
return prismaClientValidationError(this.db, message);
1117+
return prismaClientValidationError(this.db, this.options, message);
11151118
}
11161119

11171120
unknownError(message: string) {
1118-
return prismaClientUnknownRequestError(this.db, message, {
1121+
return prismaClientUnknownRequestError(this.db, this.options, message, {
11191122
clientVersion: getVersion(),
11201123
});
11211124
}

packages/runtime/src/enhancements/types.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ export interface CommonEnhancementOptions {
1919
* Path for loading CLI-generated code
2020
*/
2121
loadPath?: string;
22+
23+
/**
24+
* The `Prisma` module generated together with `PrismaClient`. You only need to
25+
* pass it when you specified a custom `PrismaClient` output path. The module can
26+
* be loaded like: `import { Prisma } from '<your PrismaClient output path>';`.
27+
*/
28+
prismaModule?: any;
2229
}
2330

2431
/**

packages/runtime/src/enhancements/utils.ts

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import path from 'path';
44
import * as util from 'util';
55
import type { DbClientContract } from '../types';
6+
import type { EnhancementOptions } from './enhance';
67

78
/**
89
* Formats an object for pretty printing.
@@ -53,25 +54,37 @@ function loadPrismaModule(prisma: any) {
5354
}
5455
}
5556

56-
export function prismaClientValidationError(prisma: DbClientContract, message: string) {
57+
export function prismaClientValidationError(
58+
prisma: DbClientContract,
59+
options: EnhancementOptions | undefined,
60+
message: string
61+
) {
5762
if (!_PrismaClientValidationError) {
58-
const _prisma = loadPrismaModule(prisma);
63+
const _prisma = options?.prismaModule ?? loadPrismaModule(prisma);
5964
_PrismaClientValidationError = _prisma.PrismaClientValidationError;
6065
}
6166
throw new _PrismaClientValidationError(message, { clientVersion: prisma._clientVersion });
6267
}
6368

64-
export function prismaClientKnownRequestError(prisma: DbClientContract, ...args: unknown[]) {
69+
export function prismaClientKnownRequestError(
70+
prisma: DbClientContract,
71+
options: EnhancementOptions | undefined,
72+
...args: unknown[]
73+
) {
6574
if (!_PrismaClientKnownRequestError) {
66-
const _prisma = loadPrismaModule(prisma);
75+
const _prisma = options?.prismaModule ?? loadPrismaModule(prisma);
6776
_PrismaClientKnownRequestError = _prisma.PrismaClientKnownRequestError;
6877
}
6978
return new _PrismaClientKnownRequestError(...args);
7079
}
7180

72-
export function prismaClientUnknownRequestError(prisma: DbClientContract, ...args: unknown[]) {
81+
export function prismaClientUnknownRequestError(
82+
prisma: DbClientContract,
83+
options: EnhancementOptions | undefined,
84+
...args: unknown[]
85+
) {
7386
if (!_PrismaClientUnknownRequestError) {
74-
const _prisma = loadPrismaModule(prisma);
87+
const _prisma = options?.prismaModule ?? loadPrismaModule(prisma);
7588
_PrismaClientUnknownRequestError = _prisma.PrismaClientUnknownRequestError;
7689
}
7790
throw new _PrismaClientUnknownRequestError(...args);

tests/integration/tests/enhancements/with-policy/options.test.ts

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { withPolicy } from '@zenstackhq/runtime';
1+
import { enhance } from '@zenstackhq/runtime';
22
import { loadSchema } from '@zenstackhq/testtools';
33
import path from 'path';
44

@@ -20,17 +20,48 @@ describe('Password test', () => {
2020
id String @id @default(cuid())
2121
x Int
2222
23+
@@allow('read', true)
2324
@@allow('create', x > 0)
2425
}`,
2526
{ getPrismaOnly: true, output: './zen' }
2627
);
2728

28-
const db = withPolicy(prisma, undefined, { loadPath: './zen' });
29+
const db = enhance(prisma, undefined, { loadPath: './zen' });
2930
await expect(
3031
db.foo.create({
3132
data: { x: 0 },
3233
})
3334
).toBeRejectedByPolicy();
35+
await expect(
36+
db.foo.create({
37+
data: { x: 1 },
38+
})
39+
).toResolveTruthy();
40+
});
41+
42+
it('prisma module', async () => {
43+
const { prisma, Prisma, modelMeta, policy } = await loadSchema(
44+
`
45+
model Foo {
46+
id String @id @default(cuid())
47+
x Int
48+
49+
@@allow('read', true)
50+
@@allow('create', x > 0)
51+
}`
52+
);
53+
54+
const db = enhance(prisma, undefined, { modelMeta, policy, prismaModule: Prisma });
55+
await expect(
56+
db.foo.create({
57+
data: { x: 0 },
58+
})
59+
).toBeRejectedByPolicy();
60+
await expect(
61+
db.foo.create({
62+
data: { x: 1 },
63+
})
64+
).toResolveTruthy();
3465
});
3566

3667
it('overrides', async () => {
@@ -45,7 +76,7 @@ describe('Password test', () => {
4576
{ getPrismaOnly: true, output: './zen' }
4677
);
4778

48-
const db = withPolicy(prisma, undefined, {
79+
const db = enhance(prisma, undefined, {
4980
modelMeta: require(path.resolve('./zen/model-meta')).default,
5081
policy: require(path.resolve('./zen/policy')).default,
5182
});

0 commit comments

Comments
 (0)