Skip to content

feat: fluent API support #666

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 151 additions & 85 deletions packages/runtime/src/enhancements/policy/handler.ts

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion packages/runtime/src/enhancements/policy/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ export type WithPolicyOptions = {
policy?: PolicyDef;

/**
* Model metatadata
* Model metadata
*/
modelMeta?: ModelMeta;

Expand Down
62 changes: 33 additions & 29 deletions packages/runtime/src/enhancements/policy/policy-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ export class PolicyUtil {
/**
* Injects model auth guard as where clause.
*/
async injectAuthGuard(db: Record<string, DbOperations>, args: any, model: string, operation: PolicyOperationKind) {
injectAuthGuard(db: Record<string, DbOperations>, args: any, model: string, operation: PolicyOperationKind) {
let guard = this.getAuthGuard(db, model, operation);
if (this.isFalse(guard)) {
args.where = this.makeFalse();
Expand All @@ -277,14 +277,14 @@ export class PolicyUtil {
// inject into relation fields:
// to-many: some/none/every
// to-one: direct-conditions/is/isNot
await this.injectGuardForRelationFields(db, model, args.where, operation);
this.injectGuardForRelationFields(db, model, args.where, operation);
}

args.where = this.and(args.where, guard);
return true;
}

private async injectGuardForRelationFields(
private injectGuardForRelationFields(
db: Record<string, DbOperations>,
model: string,
payload: any,
Expand All @@ -295,33 +295,33 @@ export class PolicyUtil {
continue;
}

const fieldInfo = await resolveField(this.modelMeta, model, field);
const fieldInfo = resolveField(this.modelMeta, model, field);
if (!fieldInfo || !fieldInfo.isDataModel) {
continue;
}

if (fieldInfo.isArray) {
await this.injectGuardForToManyField(db, fieldInfo, subPayload, operation);
this.injectGuardForToManyField(db, fieldInfo, subPayload, operation);
} else {
await this.injectGuardForToOneField(db, fieldInfo, subPayload, operation);
this.injectGuardForToOneField(db, fieldInfo, subPayload, operation);
}
}
}

private async injectGuardForToManyField(
private injectGuardForToManyField(
db: Record<string, DbOperations>,
fieldInfo: FieldInfo,
payload: { some?: any; every?: any; none?: any },
operation: PolicyOperationKind
) {
const guard = this.getAuthGuard(db, fieldInfo.type, operation);
if (payload.some) {
await this.injectGuardForRelationFields(db, fieldInfo.type, payload.some, operation);
this.injectGuardForRelationFields(db, fieldInfo.type, payload.some, operation);
// turn "some" into: { some: { AND: [guard, payload.some] } }
payload.some = this.and(payload.some, guard);
}
if (payload.none) {
await this.injectGuardForRelationFields(db, fieldInfo.type, payload.none, operation);
this.injectGuardForRelationFields(db, fieldInfo.type, payload.none, operation);
// turn none into: { none: { AND: [guard, payload.none] } }
payload.none = this.and(payload.none, guard);
}
Expand All @@ -331,7 +331,7 @@ export class PolicyUtil {
// ignore empty every clause
Object.keys(payload.every).length > 0
) {
await this.injectGuardForRelationFields(db, fieldInfo.type, payload.every, operation);
this.injectGuardForRelationFields(db, fieldInfo.type, payload.every, operation);

// turn "every" into: { none: { AND: [guard, { NOT: payload.every }] } }
if (!payload.none) {
Expand All @@ -342,7 +342,7 @@ export class PolicyUtil {
}
}

private async injectGuardForToOneField(
private injectGuardForToOneField(
db: Record<string, DbOperations>,
fieldInfo: FieldInfo,
payload: { is?: any; isNot?: any } & Record<string, any>,
Expand All @@ -351,18 +351,18 @@ export class PolicyUtil {
const guard = this.getAuthGuard(db, fieldInfo.type, operation);
if (payload.is || payload.isNot) {
if (payload.is) {
await this.injectGuardForRelationFields(db, fieldInfo.type, payload.is, operation);
this.injectGuardForRelationFields(db, fieldInfo.type, payload.is, operation);
// turn "is" into: { is: { AND: [ originalIs, guard ] }
payload.is = this.and(payload.is, guard);
}
if (payload.isNot) {
await this.injectGuardForRelationFields(db, fieldInfo.type, payload.isNot, operation);
this.injectGuardForRelationFields(db, fieldInfo.type, payload.isNot, operation);
// turn "isNot" into: { isNot: { AND: [ originalIsNot, { NOT: guard } ] } }
payload.isNot = this.and(payload.isNot, this.not(guard));
delete payload.isNot;
}
} else {
await this.injectGuardForRelationFields(db, fieldInfo.type, payload, operation);
this.injectGuardForRelationFields(db, fieldInfo.type, payload, operation);
// turn direct conditions into: { is: { AND: [ originalConditions, guard ] } }
const combined = this.and(deepcopy(payload), guard);
Object.keys(payload).forEach((key) => delete payload[key]);
Expand All @@ -373,17 +373,17 @@ export class PolicyUtil {
/**
* Injects auth guard for read operations.
*/
async injectForRead(db: Record<string, DbOperations>, model: string, args: any) {
injectForRead(db: Record<string, DbOperations>, model: string, args: any) {
const injected: any = {};
if (!(await this.injectAuthGuard(db, injected, model, 'read'))) {
if (!this.injectAuthGuard(db, injected, model, 'read')) {
return false;
}

if (args.where) {
// inject into relation fields:
// to-many: some/none/every
// to-one: direct-conditions/is/isNot
await this.injectGuardForRelationFields(db, model, args.where, 'read');
this.injectGuardForRelationFields(db, model, args.where, 'read');
}

if (injected.where && Object.keys(injected.where).length > 0 && !this.isTrue(injected.where)) {
Expand All @@ -395,7 +395,7 @@ export class PolicyUtil {
}

// recursively inject read guard conditions into nested select, include, and _count
const hoistedConditions = await this.injectNestedReadConditions(db, model, args);
const hoistedConditions = this.injectNestedReadConditions(db, model, args);

// the injection process may generate conditions that need to be hoisted to the toplevel,
// if so, merge it with the existing where
Expand Down Expand Up @@ -441,7 +441,7 @@ export class PolicyUtil {
/**
* Builds a reversed query for the given nested path.
*/
async buildReversedQuery(context: NestedWriteVisitorContext) {
buildReversedQuery(context: NestedWriteVisitorContext) {
let result, currQuery: any;
let currField: FieldInfo | undefined;

Expand Down Expand Up @@ -489,11 +489,7 @@ export class PolicyUtil {
return result;
}

private async injectNestedReadConditions(
db: Record<string, DbOperations>,
model: string,
args: any
): Promise<any[]> {
private injectNestedReadConditions(db: Record<string, DbOperations>, model: string, args: any): any[] {
const injectTarget = args.select ?? args.include;
if (!injectTarget) {
return [];
Expand Down Expand Up @@ -526,7 +522,7 @@ export class PolicyUtil {
continue;
}
// inject into the "where" clause inside select
await this.injectAuthGuard(db, injectTarget._count.select[field], fieldInfo.type, 'read');
this.injectAuthGuard(db, injectTarget._count.select[field], fieldInfo.type, 'read');
}
}

Expand All @@ -552,10 +548,10 @@ export class PolicyUtil {
injectTarget[field] = {};
}
// inject extra condition for to-many or nullable to-one relation
await this.injectAuthGuard(db, injectTarget[field], fieldInfo.type, 'read');
this.injectAuthGuard(db, injectTarget[field], fieldInfo.type, 'read');

// recurse
const subHoisted = await this.injectNestedReadConditions(db, fieldInfo.type, injectTarget[field]);
const subHoisted = this.injectNestedReadConditions(db, fieldInfo.type, injectTarget[field]);
if (subHoisted.length > 0) {
// we can convert it to a where at this level
injectTarget[field].where = this.and(injectTarget[field].where, ...subHoisted);
Expand All @@ -564,7 +560,7 @@ export class PolicyUtil {
// hoist non-nullable to-one filter to the parent level
hoisted = this.getAuthGuard(db, fieldInfo.type, 'read');
// recurse
const subHoisted = await this.injectNestedReadConditions(db, fieldInfo.type, injectTarget[field]);
const subHoisted = this.injectNestedReadConditions(db, fieldInfo.type, injectTarget[field]);
if (subHoisted.length > 0) {
hoisted = this.and(hoisted, ...subHoisted);
}
Expand Down Expand Up @@ -732,7 +728,7 @@ export class PolicyUtil {
CrudFailureReason.RESULT_NOT_READABLE
);

const injectResult = await this.injectForRead(db, model, readArgs);
const injectResult = this.injectForRead(db, model, readArgs);
if (!injectResult) {
return { error, result: undefined };
}
Expand Down Expand Up @@ -1011,6 +1007,14 @@ export class PolicyUtil {
}
}

/**
* Gets information for all fields of a model.
*/
getModelFields(model: string) {
model = lowerCaseFirst(model);
return this.modelMeta.fields[model];
}

/**
* Gets information for a specific model field.
*/
Expand Down
38 changes: 38 additions & 0 deletions packages/runtime/src/enhancements/policy/promise.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

/**
* Creates a promise that only executes when it's awaited or .then() is called.
* @see https://github.com/prisma/prisma/blob/main/packages/client/src/runtime/core/request/createPrismaPromise.ts
*/
export function createDeferredPromise<T>(callback: () => Promise<T>): Promise<T> {
let promise: Promise<T> | undefined;
const cb = () => {
try {
return (promise ??= valueToPromise(callback()));
} catch (err) {
// deal with synchronous errors
return Promise.reject<T>(err);
}
};

return {
then(onFulfilled, onRejected) {
return cb().then(onFulfilled, onRejected);
},
catch(onRejected) {
return cb().catch(onRejected);
},
finally(onFinally) {
return cb().finally(onFinally);
},
[Symbol.toStringTag]: 'ZenStackPromise',
};
}

function valueToPromise(thing: any): Promise<any> {
if (typeof thing === 'object' && typeof thing?.then === 'function') {
return thing;
} else {
return Promise.resolve(thing);
}
}
48 changes: 32 additions & 16 deletions packages/runtime/src/enhancements/proxy.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import { PRISMA_TX_FLAG, PRISMA_PROXY_ENHANCER } from '../constants';
import { PRISMA_PROXY_ENHANCER, PRISMA_TX_FLAG } from '../constants';
import { DbClientContract } from '../types';
import { createDeferredPromise } from './policy/promise';
import { ModelMeta } from './types';

/**
Expand Down Expand Up @@ -174,11 +175,7 @@ export function makeProxy<T extends PrismaProxyHandler>(
modelMeta: ModelMeta,
makeHandler: (prisma: object, model: string) => T,
name = 'unnamed_enhancer'
// inTransaction = false
) {
// // put a transaction marker on the proxy target
// prisma[PRISIMA_TX_FLAG] = inTransaction;

const models = Object.keys(modelMeta.fields).map((k) => k.toLowerCase());
const proxy = new Proxy(prisma, {
get: (target: any, prop: string | symbol, receiver: any) => {
Expand Down Expand Up @@ -248,20 +245,39 @@ function createHandlerProxy<T extends PrismaProxyHandler>(handler: T): T {

// eslint-disable-next-line @typescript-eslint/ban-types
const origMethod = prop as Function;
return async function (...args: any[]) {
// proxying async functions results in messed-up error stack trace,
return function (...args: any[]) {
// using proxy with async functions results in messed-up error stack trace,
// create an error to capture the current stack
const capture = new Error(ERROR_MARKER);
try {
return await origMethod.apply(handler, args);
} catch (err) {
if (capture.stack && err instanceof Error) {
// save the original stack and replace it with a clean one
(err as any).internalStack = err.stack;
err.stack = cleanCallStack(capture.stack, propKey.toString(), err.message);

// the original proxy returned by the PrismaClient proxy
const promise: Promise<unknown> = origMethod.apply(handler, args);

// modify the error stack
const resultPromise = createDeferredPromise(() => {
return new Promise((resolve, reject) => {
promise.then(
(value) => resolve(value),
(err) => {
if (capture.stack && err instanceof Error) {
// save the original stack and replace it with a clean one
(err as any).internalStack = err.stack;
err.stack = cleanCallStack(capture.stack, propKey.toString(), err.message);
}
reject(err);
}
);
});
});

// carry over extra fields from the original promise
for (const [k, v] of Object.entries(promise)) {
if (!(k in resultPromise)) {
(resultPromise as any)[k] = v;
}
throw err;
}

return resultPromise;
};
},
});
Expand All @@ -287,7 +303,7 @@ function cleanCallStack(stack: string, method: string, message: string) {
}

// skip leading zenstack and anonymous lines
if (line.includes('@zenstackhq/runtime') || line.includes('<anonymous>')) {
if (line.includes('@zenstackhq/runtime') || line.includes('Proxy.<anonymous>')) {
continue;
}

Expand Down
28 changes: 15 additions & 13 deletions packages/runtime/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

export type PrismaPromise<T> = Promise<T> & Record<string, (args?: any) => PrismaPromise<any>>;

/**
* Weakly-typed database access methods
*/
export interface DbOperations {
findMany(args?: unknown): Promise<unknown[]>;
findFirst(args: unknown): Promise<unknown>;
findFirstOrThrow(args: unknown): Promise<unknown>;
findUnique(args: unknown): Promise<unknown>;
findUniqueOrThrow(args: unknown): Promise<unknown>;
create(args: unknown): Promise<unknown>;
findMany(args?: unknown): Promise<any[]>;
findFirst(args?: unknown): PrismaPromise<any>;
findFirstOrThrow(args?: unknown): PrismaPromise<any>;
findUnique(args: unknown): PrismaPromise<any>;
findUniqueOrThrow(args: unknown): PrismaPromise<any>;
create(args: unknown): Promise<any>;
createMany(args: unknown, skipDuplicates?: boolean): Promise<{ count: number }>;
update(args: unknown): Promise<unknown>;
update(args: unknown): Promise<any>;
updateMany(args: unknown): Promise<{ count: number }>;
upsert(args: unknown): Promise<unknown>;
delete(args: unknown): Promise<unknown>;
upsert(args: unknown): Promise<any>;
delete(args: unknown): Promise<any>;
deleteMany(args?: unknown): Promise<{ count: number }>;
aggregate(args: unknown): Promise<unknown>;
groupBy(args: unknown): Promise<unknown>;
count(args?: unknown): Promise<unknown>;
subscribe(args?: unknown): Promise<unknown>;
aggregate(args: unknown): Promise<any>;
groupBy(args: unknown): Promise<any>;
count(args?: unknown): Promise<any>;
subscribe(args?: unknown): Promise<any>;
fields: Record<string, any>;
}

Expand Down
8 changes: 6 additions & 2 deletions packages/schema/src/utils/version-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ export function getVersion() {
try {
return require('../package.json').version;
} catch {
// dev environment
return require('../../package.json').version;
try {
// dev environment
return require('../../package.json').version;
} catch {
return undefined;
}
}
}
Loading