Skip to content

Commit 4ae5a96

Browse files
authored
feat: fluent API support (#666)
1 parent b44976d commit 4ae5a96

File tree

21 files changed

+423
-193
lines changed

21 files changed

+423
-193
lines changed

packages/runtime/src/enhancements/policy/handler.ts

Lines changed: 151 additions & 85 deletions
Large diffs are not rendered by default.

packages/runtime/src/enhancements/policy/index.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ export type WithPolicyOptions = {
2929
policy?: PolicyDef;
3030

3131
/**
32-
* Model metatadata
32+
* Model metadata
3333
*/
3434
modelMeta?: ModelMeta;
3535

packages/runtime/src/enhancements/policy/policy-utils.ts

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ export class PolicyUtil {
253253
/**
254254
* Injects model auth guard as where clause.
255255
*/
256-
async injectAuthGuard(db: Record<string, DbOperations>, args: any, model: string, operation: PolicyOperationKind) {
256+
injectAuthGuard(db: Record<string, DbOperations>, args: any, model: string, operation: PolicyOperationKind) {
257257
let guard = this.getAuthGuard(db, model, operation);
258258
if (this.isFalse(guard)) {
259259
args.where = this.makeFalse();
@@ -277,14 +277,14 @@ export class PolicyUtil {
277277
// inject into relation fields:
278278
// to-many: some/none/every
279279
// to-one: direct-conditions/is/isNot
280-
await this.injectGuardForRelationFields(db, model, args.where, operation);
280+
this.injectGuardForRelationFields(db, model, args.where, operation);
281281
}
282282

283283
args.where = this.and(args.where, guard);
284284
return true;
285285
}
286286

287-
private async injectGuardForRelationFields(
287+
private injectGuardForRelationFields(
288288
db: Record<string, DbOperations>,
289289
model: string,
290290
payload: any,
@@ -295,33 +295,33 @@ export class PolicyUtil {
295295
continue;
296296
}
297297

298-
const fieldInfo = await resolveField(this.modelMeta, model, field);
298+
const fieldInfo = resolveField(this.modelMeta, model, field);
299299
if (!fieldInfo || !fieldInfo.isDataModel) {
300300
continue;
301301
}
302302

303303
if (fieldInfo.isArray) {
304-
await this.injectGuardForToManyField(db, fieldInfo, subPayload, operation);
304+
this.injectGuardForToManyField(db, fieldInfo, subPayload, operation);
305305
} else {
306-
await this.injectGuardForToOneField(db, fieldInfo, subPayload, operation);
306+
this.injectGuardForToOneField(db, fieldInfo, subPayload, operation);
307307
}
308308
}
309309
}
310310

311-
private async injectGuardForToManyField(
311+
private injectGuardForToManyField(
312312
db: Record<string, DbOperations>,
313313
fieldInfo: FieldInfo,
314314
payload: { some?: any; every?: any; none?: any },
315315
operation: PolicyOperationKind
316316
) {
317317
const guard = this.getAuthGuard(db, fieldInfo.type, operation);
318318
if (payload.some) {
319-
await this.injectGuardForRelationFields(db, fieldInfo.type, payload.some, operation);
319+
this.injectGuardForRelationFields(db, fieldInfo.type, payload.some, operation);
320320
// turn "some" into: { some: { AND: [guard, payload.some] } }
321321
payload.some = this.and(payload.some, guard);
322322
}
323323
if (payload.none) {
324-
await this.injectGuardForRelationFields(db, fieldInfo.type, payload.none, operation);
324+
this.injectGuardForRelationFields(db, fieldInfo.type, payload.none, operation);
325325
// turn none into: { none: { AND: [guard, payload.none] } }
326326
payload.none = this.and(payload.none, guard);
327327
}
@@ -331,7 +331,7 @@ export class PolicyUtil {
331331
// ignore empty every clause
332332
Object.keys(payload.every).length > 0
333333
) {
334-
await this.injectGuardForRelationFields(db, fieldInfo.type, payload.every, operation);
334+
this.injectGuardForRelationFields(db, fieldInfo.type, payload.every, operation);
335335

336336
// turn "every" into: { none: { AND: [guard, { NOT: payload.every }] } }
337337
if (!payload.none) {
@@ -342,7 +342,7 @@ export class PolicyUtil {
342342
}
343343
}
344344

345-
private async injectGuardForToOneField(
345+
private injectGuardForToOneField(
346346
db: Record<string, DbOperations>,
347347
fieldInfo: FieldInfo,
348348
payload: { is?: any; isNot?: any } & Record<string, any>,
@@ -351,18 +351,18 @@ export class PolicyUtil {
351351
const guard = this.getAuthGuard(db, fieldInfo.type, operation);
352352
if (payload.is || payload.isNot) {
353353
if (payload.is) {
354-
await this.injectGuardForRelationFields(db, fieldInfo.type, payload.is, operation);
354+
this.injectGuardForRelationFields(db, fieldInfo.type, payload.is, operation);
355355
// turn "is" into: { is: { AND: [ originalIs, guard ] }
356356
payload.is = this.and(payload.is, guard);
357357
}
358358
if (payload.isNot) {
359-
await this.injectGuardForRelationFields(db, fieldInfo.type, payload.isNot, operation);
359+
this.injectGuardForRelationFields(db, fieldInfo.type, payload.isNot, operation);
360360
// turn "isNot" into: { isNot: { AND: [ originalIsNot, { NOT: guard } ] } }
361361
payload.isNot = this.and(payload.isNot, this.not(guard));
362362
delete payload.isNot;
363363
}
364364
} else {
365-
await this.injectGuardForRelationFields(db, fieldInfo.type, payload, operation);
365+
this.injectGuardForRelationFields(db, fieldInfo.type, payload, operation);
366366
// turn direct conditions into: { is: { AND: [ originalConditions, guard ] } }
367367
const combined = this.and(deepcopy(payload), guard);
368368
Object.keys(payload).forEach((key) => delete payload[key]);
@@ -373,17 +373,17 @@ export class PolicyUtil {
373373
/**
374374
* Injects auth guard for read operations.
375375
*/
376-
async injectForRead(db: Record<string, DbOperations>, model: string, args: any) {
376+
injectForRead(db: Record<string, DbOperations>, model: string, args: any) {
377377
const injected: any = {};
378-
if (!(await this.injectAuthGuard(db, injected, model, 'read'))) {
378+
if (!this.injectAuthGuard(db, injected, model, 'read')) {
379379
return false;
380380
}
381381

382382
if (args.where) {
383383
// inject into relation fields:
384384
// to-many: some/none/every
385385
// to-one: direct-conditions/is/isNot
386-
await this.injectGuardForRelationFields(db, model, args.where, 'read');
386+
this.injectGuardForRelationFields(db, model, args.where, 'read');
387387
}
388388

389389
if (injected.where && Object.keys(injected.where).length > 0 && !this.isTrue(injected.where)) {
@@ -395,7 +395,7 @@ export class PolicyUtil {
395395
}
396396

397397
// recursively inject read guard conditions into nested select, include, and _count
398-
const hoistedConditions = await this.injectNestedReadConditions(db, model, args);
398+
const hoistedConditions = this.injectNestedReadConditions(db, model, args);
399399

400400
// the injection process may generate conditions that need to be hoisted to the toplevel,
401401
// if so, merge it with the existing where
@@ -441,7 +441,7 @@ export class PolicyUtil {
441441
/**
442442
* Builds a reversed query for the given nested path.
443443
*/
444-
async buildReversedQuery(context: NestedWriteVisitorContext) {
444+
buildReversedQuery(context: NestedWriteVisitorContext) {
445445
let result, currQuery: any;
446446
let currField: FieldInfo | undefined;
447447

@@ -489,11 +489,7 @@ export class PolicyUtil {
489489
return result;
490490
}
491491

492-
private async injectNestedReadConditions(
493-
db: Record<string, DbOperations>,
494-
model: string,
495-
args: any
496-
): Promise<any[]> {
492+
private injectNestedReadConditions(db: Record<string, DbOperations>, model: string, args: any): any[] {
497493
const injectTarget = args.select ?? args.include;
498494
if (!injectTarget) {
499495
return [];
@@ -526,7 +522,7 @@ export class PolicyUtil {
526522
continue;
527523
}
528524
// inject into the "where" clause inside select
529-
await this.injectAuthGuard(db, injectTarget._count.select[field], fieldInfo.type, 'read');
525+
this.injectAuthGuard(db, injectTarget._count.select[field], fieldInfo.type, 'read');
530526
}
531527
}
532528

@@ -552,10 +548,10 @@ export class PolicyUtil {
552548
injectTarget[field] = {};
553549
}
554550
// inject extra condition for to-many or nullable to-one relation
555-
await this.injectAuthGuard(db, injectTarget[field], fieldInfo.type, 'read');
551+
this.injectAuthGuard(db, injectTarget[field], fieldInfo.type, 'read');
556552

557553
// recurse
558-
const subHoisted = await this.injectNestedReadConditions(db, fieldInfo.type, injectTarget[field]);
554+
const subHoisted = this.injectNestedReadConditions(db, fieldInfo.type, injectTarget[field]);
559555
if (subHoisted.length > 0) {
560556
// we can convert it to a where at this level
561557
injectTarget[field].where = this.and(injectTarget[field].where, ...subHoisted);
@@ -564,7 +560,7 @@ export class PolicyUtil {
564560
// hoist non-nullable to-one filter to the parent level
565561
hoisted = this.getAuthGuard(db, fieldInfo.type, 'read');
566562
// recurse
567-
const subHoisted = await this.injectNestedReadConditions(db, fieldInfo.type, injectTarget[field]);
563+
const subHoisted = this.injectNestedReadConditions(db, fieldInfo.type, injectTarget[field]);
568564
if (subHoisted.length > 0) {
569565
hoisted = this.and(hoisted, ...subHoisted);
570566
}
@@ -732,7 +728,7 @@ export class PolicyUtil {
732728
CrudFailureReason.RESULT_NOT_READABLE
733729
);
734730

735-
const injectResult = await this.injectForRead(db, model, readArgs);
731+
const injectResult = this.injectForRead(db, model, readArgs);
736732
if (!injectResult) {
737733
return { error, result: undefined };
738734
}
@@ -1011,6 +1007,14 @@ export class PolicyUtil {
10111007
}
10121008
}
10131009

1010+
/**
1011+
* Gets information for all fields of a model.
1012+
*/
1013+
getModelFields(model: string) {
1014+
model = lowerCaseFirst(model);
1015+
return this.modelMeta.fields[model];
1016+
}
1017+
10141018
/**
10151019
* Gets information for a specific model field.
10161020
*/
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/* eslint-disable @typescript-eslint/no-explicit-any */
2+
3+
/**
4+
* Creates a promise that only executes when it's awaited or .then() is called.
5+
* @see https://github.com/prisma/prisma/blob/main/packages/client/src/runtime/core/request/createPrismaPromise.ts
6+
*/
7+
export function createDeferredPromise<T>(callback: () => Promise<T>): Promise<T> {
8+
let promise: Promise<T> | undefined;
9+
const cb = () => {
10+
try {
11+
return (promise ??= valueToPromise(callback()));
12+
} catch (err) {
13+
// deal with synchronous errors
14+
return Promise.reject<T>(err);
15+
}
16+
};
17+
18+
return {
19+
then(onFulfilled, onRejected) {
20+
return cb().then(onFulfilled, onRejected);
21+
},
22+
catch(onRejected) {
23+
return cb().catch(onRejected);
24+
},
25+
finally(onFinally) {
26+
return cb().finally(onFinally);
27+
},
28+
[Symbol.toStringTag]: 'ZenStackPromise',
29+
};
30+
}
31+
32+
function valueToPromise(thing: any): Promise<any> {
33+
if (typeof thing === 'object' && typeof thing?.then === 'function') {
34+
return thing;
35+
} else {
36+
return Promise.resolve(thing);
37+
}
38+
}

packages/runtime/src/enhancements/proxy.ts

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
/* eslint-disable @typescript-eslint/no-explicit-any */
22

3-
import { PRISMA_TX_FLAG, PRISMA_PROXY_ENHANCER } from '../constants';
3+
import { PRISMA_PROXY_ENHANCER, PRISMA_TX_FLAG } from '../constants';
44
import { DbClientContract } from '../types';
5+
import { createDeferredPromise } from './policy/promise';
56
import { ModelMeta } from './types';
67

78
/**
@@ -174,11 +175,7 @@ export function makeProxy<T extends PrismaProxyHandler>(
174175
modelMeta: ModelMeta,
175176
makeHandler: (prisma: object, model: string) => T,
176177
name = 'unnamed_enhancer'
177-
// inTransaction = false
178178
) {
179-
// // put a transaction marker on the proxy target
180-
// prisma[PRISIMA_TX_FLAG] = inTransaction;
181-
182179
const models = Object.keys(modelMeta.fields).map((k) => k.toLowerCase());
183180
const proxy = new Proxy(prisma, {
184181
get: (target: any, prop: string | symbol, receiver: any) => {
@@ -248,20 +245,39 @@ function createHandlerProxy<T extends PrismaProxyHandler>(handler: T): T {
248245

249246
// eslint-disable-next-line @typescript-eslint/ban-types
250247
const origMethod = prop as Function;
251-
return async function (...args: any[]) {
252-
// proxying async functions results in messed-up error stack trace,
248+
return function (...args: any[]) {
249+
// using proxy with async functions results in messed-up error stack trace,
253250
// create an error to capture the current stack
254251
const capture = new Error(ERROR_MARKER);
255-
try {
256-
return await origMethod.apply(handler, args);
257-
} catch (err) {
258-
if (capture.stack && err instanceof Error) {
259-
// save the original stack and replace it with a clean one
260-
(err as any).internalStack = err.stack;
261-
err.stack = cleanCallStack(capture.stack, propKey.toString(), err.message);
252+
253+
// the original proxy returned by the PrismaClient proxy
254+
const promise: Promise<unknown> = origMethod.apply(handler, args);
255+
256+
// modify the error stack
257+
const resultPromise = createDeferredPromise(() => {
258+
return new Promise((resolve, reject) => {
259+
promise.then(
260+
(value) => resolve(value),
261+
(err) => {
262+
if (capture.stack && err instanceof Error) {
263+
// save the original stack and replace it with a clean one
264+
(err as any).internalStack = err.stack;
265+
err.stack = cleanCallStack(capture.stack, propKey.toString(), err.message);
266+
}
267+
reject(err);
268+
}
269+
);
270+
});
271+
});
272+
273+
// carry over extra fields from the original promise
274+
for (const [k, v] of Object.entries(promise)) {
275+
if (!(k in resultPromise)) {
276+
(resultPromise as any)[k] = v;
262277
}
263-
throw err;
264278
}
279+
280+
return resultPromise;
265281
};
266282
},
267283
});
@@ -287,7 +303,7 @@ function cleanCallStack(stack: string, method: string, message: string) {
287303
}
288304

289305
// skip leading zenstack and anonymous lines
290-
if (line.includes('@zenstackhq/runtime') || line.includes('<anonymous>')) {
306+
if (line.includes('@zenstackhq/runtime') || line.includes('Proxy.<anonymous>')) {
291307
continue;
292308
}
293309

packages/runtime/src/types.ts

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,27 @@
11
/* eslint-disable @typescript-eslint/no-explicit-any */
22

3+
export type PrismaPromise<T> = Promise<T> & Record<string, (args?: any) => PrismaPromise<any>>;
4+
35
/**
46
* Weakly-typed database access methods
57
*/
68
export interface DbOperations {
7-
findMany(args?: unknown): Promise<unknown[]>;
8-
findFirst(args: unknown): Promise<unknown>;
9-
findFirstOrThrow(args: unknown): Promise<unknown>;
10-
findUnique(args: unknown): Promise<unknown>;
11-
findUniqueOrThrow(args: unknown): Promise<unknown>;
12-
create(args: unknown): Promise<unknown>;
9+
findMany(args?: unknown): Promise<any[]>;
10+
findFirst(args?: unknown): PrismaPromise<any>;
11+
findFirstOrThrow(args?: unknown): PrismaPromise<any>;
12+
findUnique(args: unknown): PrismaPromise<any>;
13+
findUniqueOrThrow(args: unknown): PrismaPromise<any>;
14+
create(args: unknown): Promise<any>;
1315
createMany(args: unknown, skipDuplicates?: boolean): Promise<{ count: number }>;
14-
update(args: unknown): Promise<unknown>;
16+
update(args: unknown): Promise<any>;
1517
updateMany(args: unknown): Promise<{ count: number }>;
16-
upsert(args: unknown): Promise<unknown>;
17-
delete(args: unknown): Promise<unknown>;
18+
upsert(args: unknown): Promise<any>;
19+
delete(args: unknown): Promise<any>;
1820
deleteMany(args?: unknown): Promise<{ count: number }>;
19-
aggregate(args: unknown): Promise<unknown>;
20-
groupBy(args: unknown): Promise<unknown>;
21-
count(args?: unknown): Promise<unknown>;
22-
subscribe(args?: unknown): Promise<unknown>;
21+
aggregate(args: unknown): Promise<any>;
22+
groupBy(args: unknown): Promise<any>;
23+
count(args?: unknown): Promise<any>;
24+
subscribe(args?: unknown): Promise<any>;
2325
fields: Record<string, any>;
2426
}
2527

packages/schema/src/utils/version-utils.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@ export function getVersion() {
33
try {
44
return require('../package.json').version;
55
} catch {
6-
// dev environment
7-
return require('../../package.json').version;
6+
try {
7+
// dev environment
8+
return require('../../package.json').version;
9+
} catch {
10+
return undefined;
11+
}
812
}
913
}

0 commit comments

Comments
 (0)