diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index dd3649e55..d91d6b88c 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -1,5 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ +import { lowerCaseFirst } from 'lower-case-first'; import { upperCaseFirst } from 'upper-case-first'; import { fromZodError } from 'zod-validation-error'; import { CrudFailureReason, PRISMA_TX_FLAG } from '../../constants'; @@ -12,6 +13,7 @@ import type { ModelMeta, PolicyDef, ZodSchemas } from '../types'; import { enumerate, formatObject, getIdFields, prismaClientValidationError } from '../utils'; import { Logger } from './logger'; import { PolicyUtil } from './policy-utils'; +import { createDeferredPromise } from './promise'; // a record for post-write policy check type PostWriteCheckRecord = { @@ -21,19 +23,22 @@ type PostWriteCheckRecord = { preValue?: any; }; +type FindOperations = 'findUnique' | 'findUniqueOrThrow' | 'findFirst' | 'findFirstOrThrow' | 'findMany'; + /** * Prisma proxy handler for injecting access policy check. */ export class PolicyProxyHandler implements PrismaProxyHandler { private readonly logger: Logger; private readonly utils: PolicyUtil; + private readonly model: string; constructor( private readonly prisma: DbClient, private readonly policy: PolicyDef, private readonly modelMeta: ModelMeta, private readonly zodSchemas: ZodSchemas | undefined, - private readonly model: string, + model: string, private readonly user?: AuthUser, private readonly logPrismaQuery?: boolean ) { @@ -46,6 +51,7 @@ export class PolicyProxyHandler implements Pr this.user, this.shouldLogQuery ); + this.model = lowerCaseFirst(model); } private get modelClient() { @@ -56,103 +62,143 @@ export class PolicyProxyHandler implements Pr // find operations behaves as if the entities that don't match access policies don't exist - async findUnique(args: any) { + findUnique(args: any) { if (!args) { throw prismaClientValidationError(this.prisma, 'query argument is required'); } if (!args.where) { throw prismaClientValidationError(this.prisma, 'where field is required in query argument'); } - - const origArgs = args; - args = this.utils.clone(args); - if (!(await this.utils.injectForRead(this.prisma, this.model, args))) { - return null; - } - - this.utils.injectReadCheckSelect(this.model, args); - - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`findUnique\` ${this.model}:\n${formatObject(args)}`); - } - const result = await this.modelClient.findUnique(args); - this.utils.postProcessForRead(result, this.model, origArgs); - return result; + return this.findWithFluentCallStubs(args, 'findUnique', false, () => null); } - async findUniqueOrThrow(args: any) { + findUniqueOrThrow(args: any) { if (!args) { throw prismaClientValidationError(this.prisma, 'query argument is required'); } if (!args.where) { throw prismaClientValidationError(this.prisma, 'where field is required in query argument'); } - - const origArgs = args; - args = this.utils.clone(args); - if (!(await this.utils.injectForRead(this.prisma, this.model, args))) { + return this.findWithFluentCallStubs(args, 'findUniqueOrThrow', true, () => { throw this.utils.notFound(this.model); - } - - this.utils.injectReadCheckSelect(this.model, args); + }); + } - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`findUniqueOrThrow\` ${this.model}:\n${formatObject(args)}`); - } - const result = await this.modelClient.findUniqueOrThrow(args); - this.utils.postProcessForRead(result, this.model, origArgs); - return result; + findFirst(args?: any) { + return this.findWithFluentCallStubs(args, 'findFirst', false, () => null); } - async findFirst(args: any) { - const origArgs = args; - args = args ? this.utils.clone(args) : {}; - if (!(await this.utils.injectForRead(this.prisma, this.model, args))) { - return null; - } + findFirstOrThrow(args: any) { + return this.findWithFluentCallStubs(args, 'findFirstOrThrow', true, () => { + throw this.utils.notFound(this.model); + }); + } - this.utils.injectReadCheckSelect(this.model, args); + findMany(args?: any) { + return createDeferredPromise(() => this.doFind(args, 'findMany', () => [])); + } - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`findFirst\` ${this.model}:\n${formatObject(args)}`); - } - const result = await this.modelClient.findFirst(args); - this.utils.postProcessForRead(result, this.model, origArgs); + // returns a promise for the given find operation, together with function stubs for fluent API calls + private findWithFluentCallStubs( + args: any, + actionName: FindOperations, + resolveRoot: boolean, + handleRejection: () => any + ) { + // create a deferred promise so it's only evaluated when awaited or .then() is called + const result = createDeferredPromise(() => this.doFind(args, actionName, handleRejection)); + this.addFluentFunctions(result, this.model, args?.where, resolveRoot ? result : undefined); return result; } - async findFirstOrThrow(args: any) { + private doFind(args: any, actionName: FindOperations, handleRejection: () => any) { const origArgs = args; - args = args ? this.utils.clone(args) : {}; - if (!(await this.utils.injectForRead(this.prisma, this.model, args))) { - throw this.utils.notFound(this.model); + const _args = this.utils.clone(args); + if (!this.utils.injectForRead(this.prisma, this.model, _args)) { + return handleRejection(); } - this.utils.injectReadCheckSelect(this.model, args); + this.utils.injectReadCheckSelect(this.model, _args); if (this.shouldLogQuery) { - this.logger.info(`[policy] \`findFirstOrThrow\` ${this.model}:\n${formatObject(args)}`); + this.logger.info(`[policy] \`${actionName}\` ${this.model}:\n${formatObject(_args)}`); } - const result = await this.modelClient.findFirstOrThrow(args); - this.utils.postProcessForRead(result, this.model, origArgs); - return result; + + return new Promise((resolve, reject) => { + this.modelClient[actionName](_args).then( + (value: any) => { + this.utils.postProcessForRead(value, this.model, origArgs); + resolve(value); + }, + (err: any) => reject(err) + ); + }); } - async findMany(args: any) { - const origArgs = args; - args = args ? this.utils.clone(args) : {}; - if (!(await this.utils.injectForRead(this.prisma, this.model, args))) { - return []; - } + // returns a fluent API call function + private fluentCall(filter: any, fieldInfo: FieldInfo, rootPromise?: Promise) { + return (args: any) => { + args = this.utils.clone(args); + + // combine the parent filter with the current one + const backLinkField = this.requireBackLink(fieldInfo); + const condition = backLinkField.isArray + ? { [backLinkField.name]: { some: filter } } + : { [backLinkField.name]: { is: filter } }; + args.where = this.utils.and(args.where, condition); + + const promise = createDeferredPromise(() => { + // Promise for fetching + const fetchFluent = (resolve: (value: unknown) => void, reject: (reason?: any) => void) => { + const handler = this.makeHandler(fieldInfo.type); + if (fieldInfo.isArray) { + // fluent call stops here + handler.findMany(args).then( + (value: any) => resolve(value), + (err: any) => reject(err) + ); + } else { + handler.findFirst(args).then( + (value) => resolve(value), + (err) => reject(err) + ); + } + }; - this.utils.injectReadCheckSelect(this.model, args); + return new Promise((resolve, reject) => { + if (rootPromise) { + // if a root promise exists, resolve it before fluent API call, + // so that fluent calls start with `findUniqueOrThrow` and `findFirstOrThrow` + // can throw error properly if the root promise is rejected + rootPromise.then( + () => fetchFluent(resolve, reject), + (err) => reject(err) + ); + } else { + fetchFluent(resolve, reject); + } + }); + }); - if (this.shouldLogQuery) { - this.logger.info(`[policy] \`findMany\` ${this.model}:\n${formatObject(args)}`); + if (!fieldInfo.isArray) { + // prepare for a chained fluent API call + this.addFluentFunctions(promise, fieldInfo.type, args.where, rootPromise); + } + + return promise; + }; + } + + // add fluent API functions to the given promise + private addFluentFunctions(promise: any, model: string, filter: any, rootPromise?: Promise) { + const fields = this.utils.getModelFields(model); + if (fields) { + for (const [field, fieldInfo] of Object.entries(fields)) { + if (fieldInfo.isDataModel) { + promise[field] = this.fluentCall(filter, fieldInfo, rootPromise); + } + } } - const result = await this.modelClient.findMany(args); - this.utils.postProcessForRead(result, this.model, origArgs); - return result; } //#endregion @@ -167,7 +213,7 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'data field is required in query argument'); } - await this.utils.tryReject(this.prisma, this.model, 'create'); + this.utils.tryReject(this.prisma, this.model, 'create'); const origArgs = args; args = this.utils.clone(args); @@ -571,7 +617,7 @@ export class PolicyProxyHandler implements Pr let createData = args; if (context.field?.backLink) { // handles the connection to upstream entity - const reversedQuery = await this.utils.buildReversedQuery(context); + const reversedQuery = this.utils.buildReversedQuery(context); if (reversedQuery[context.field.backLink]) { // the built reverse query contains a condition for the backlink field, build a "connect" with it createData = { @@ -597,7 +643,7 @@ export class PolicyProxyHandler implements Pr const _createMany = async (model: string, args: any, context: NestedWriteVisitorContext) => { if (context.field?.backLink) { // handles the connection to upstream entity - const reversedQuery = await this.utils.buildReversedQuery(context); + const reversedQuery = this.utils.buildReversedQuery(context); for (const item of enumerate(args.data)) { Object.assign(item, reversedQuery); } @@ -624,7 +670,7 @@ export class PolicyProxyHandler implements Pr const visitor = new NestedWriteVisitor(this.modelMeta, { update: async (model, args, context) => { // build a unique query including upstream conditions - const uniqueFilter = await this.utils.buildReversedQuery(context); + const uniqueFilter = this.utils.buildReversedQuery(context); // handle not-found const existing = await this.utils.checkExistence(db, model, uniqueFilter, true); @@ -675,7 +721,7 @@ export class PolicyProxyHandler implements Pr updateMany: async (model, args, context) => { // injects auth guard into where clause - await this.utils.injectAuthGuard(db, args, model, 'update'); + this.utils.injectAuthGuard(db, args, model, 'update'); // prepare for post-update check if (this.utils.hasAuthGuard(model, 'postUpdate') || this.utils.getZodSchema(model)) { @@ -684,9 +730,9 @@ export class PolicyProxyHandler implements Pr if (preValueSelect) { select = { ...select, ...preValueSelect }; } - const reversedQuery = await this.utils.buildReversedQuery(context); + const reversedQuery = this.utils.buildReversedQuery(context); const currentSetQuery = { select, where: reversedQuery }; - await this.utils.injectAuthGuard(db, currentSetQuery, model, 'read'); + this.utils.injectAuthGuard(db, currentSetQuery, model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`findMany\` ${model}:\n${formatObject(currentSetQuery)}`); @@ -728,7 +774,7 @@ export class PolicyProxyHandler implements Pr upsert: async (model, args, context) => { // build a unique query including upstream conditions - const uniqueFilter = await this.utils.buildReversedQuery(context); + const uniqueFilter = this.utils.buildReversedQuery(context); // branch based on if the update target exists const existing = await this.utils.checkExistence(db, model, uniqueFilter); @@ -779,7 +825,7 @@ export class PolicyProxyHandler implements Pr set: async (model, args, context) => { // find the set of items to be replaced - const reversedQuery = await this.utils.buildReversedQuery(context); + const reversedQuery = this.utils.buildReversedQuery(context); const findCurrSetArgs = { select: this.utils.makeIdSelection(model), where: reversedQuery, @@ -798,7 +844,7 @@ export class PolicyProxyHandler implements Pr delete: async (model, args, context) => { // build a unique query including upstream conditions - const uniqueFilter = await this.utils.buildReversedQuery(context); + const uniqueFilter = this.utils.buildReversedQuery(context); // handle not-found await this.utils.checkExistence(db, model, uniqueFilter, true); @@ -837,10 +883,10 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'data field is required in query argument'); } - await this.utils.tryReject(this.prisma, this.model, 'update'); + this.utils.tryReject(this.prisma, this.model, 'update'); args = this.utils.clone(args); - await this.utils.injectAuthGuard(this.prisma, args, this.model, 'update'); + this.utils.injectAuthGuard(this.prisma, args, this.model, 'update'); if (this.utils.hasAuthGuard(this.model, 'postUpdate') || this.utils.getZodSchema(this.model)) { // use a transaction to do post-update checks @@ -853,7 +899,7 @@ export class PolicyProxyHandler implements Pr select = { ...select, ...preValueSelect }; } const currentSetQuery = { select, where: args.where }; - await this.utils.injectAuthGuard(tx, currentSetQuery, this.model, 'read'); + this.utils.injectAuthGuard(tx, currentSetQuery, this.model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`); @@ -900,8 +946,8 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'update field is required in query argument'); } - await this.utils.tryReject(this.prisma, this.model, 'create'); - await this.utils.tryReject(this.prisma, this.model, 'update'); + this.utils.tryReject(this.prisma, this.model, 'create'); + this.utils.tryReject(this.prisma, this.model, 'update'); args = this.utils.clone(args); @@ -947,7 +993,7 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'where field is required in query argument'); } - await this.utils.tryReject(this.prisma, this.model, 'delete'); + this.utils.tryReject(this.prisma, this.model, 'delete'); const { result, error } = await this.transaction(async (tx) => { // do a read-back before delete @@ -978,11 +1024,11 @@ export class PolicyProxyHandler implements Pr } async deleteMany(args: any) { - await this.utils.tryReject(this.prisma, this.model, 'delete'); + this.utils.tryReject(this.prisma, this.model, 'delete'); // inject policy conditions args = args ?? {}; - await this.utils.injectAuthGuard(this.prisma, args, this.model, 'delete'); + this.utils.injectAuthGuard(this.prisma, args, this.model, 'delete'); // conduct the deletion if (this.shouldLogQuery) { @@ -1003,7 +1049,7 @@ export class PolicyProxyHandler implements Pr args = this.utils.clone(args); // inject policy conditions - await this.utils.injectAuthGuard(this.prisma, args, this.model, 'read'); + this.utils.injectAuthGuard(this.prisma, args, this.model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`aggregate\` ${this.model}:\n${formatObject(args)}`); @@ -1019,7 +1065,7 @@ export class PolicyProxyHandler implements Pr args = this.utils.clone(args); // inject policy conditions - await this.utils.injectAuthGuard(this.prisma, args, this.model, 'read'); + this.utils.injectAuthGuard(this.prisma, args, this.model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`groupBy\` ${this.model}:\n${formatObject(args)}`); @@ -1030,7 +1076,7 @@ export class PolicyProxyHandler implements Pr async count(args: any) { // inject policy conditions args = args ? this.utils.clone(args) : {}; - await this.utils.injectAuthGuard(this.prisma, args, this.model, 'read'); + this.utils.injectAuthGuard(this.prisma, args, this.model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`count\` ${this.model}:\n${formatObject(args)}`); @@ -1112,5 +1158,25 @@ export class PolicyProxyHandler implements Pr ); } + private makeHandler(model: string) { + return new PolicyProxyHandler( + this.prisma, + this.policy, + this.modelMeta, + this.zodSchemas, + model, + this.user, + this.logPrismaQuery + ); + } + + private requireBackLink(fieldInfo: FieldInfo) { + const backLinkField = fieldInfo.backLink && resolveField(this.modelMeta, fieldInfo.type, fieldInfo.backLink); + if (!backLinkField) { + throw new Error('Missing back link for field: ' + fieldInfo.name); + } + return backLinkField; + } + //#endregion } diff --git a/packages/runtime/src/enhancements/policy/index.ts b/packages/runtime/src/enhancements/policy/index.ts index 3da47b86a..afd548750 100644 --- a/packages/runtime/src/enhancements/policy/index.ts +++ b/packages/runtime/src/enhancements/policy/index.ts @@ -29,7 +29,7 @@ export type WithPolicyOptions = { policy?: PolicyDef; /** - * Model metatadata + * Model metadata */ modelMeta?: ModelMeta; diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 459028007..e16008299 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -253,7 +253,7 @@ export class PolicyUtil { /** * Injects model auth guard as where clause. */ - async injectAuthGuard(db: Record, args: any, model: string, operation: PolicyOperationKind) { + injectAuthGuard(db: Record, args: any, model: string, operation: PolicyOperationKind) { let guard = this.getAuthGuard(db, model, operation); if (this.isFalse(guard)) { args.where = this.makeFalse(); @@ -277,14 +277,14 @@ export class PolicyUtil { // inject into relation fields: // to-many: some/none/every // to-one: direct-conditions/is/isNot - await this.injectGuardForRelationFields(db, model, args.where, operation); + this.injectGuardForRelationFields(db, model, args.where, operation); } args.where = this.and(args.where, guard); return true; } - private async injectGuardForRelationFields( + private injectGuardForRelationFields( db: Record, model: string, payload: any, @@ -295,20 +295,20 @@ export class PolicyUtil { continue; } - const fieldInfo = await resolveField(this.modelMeta, model, field); + const fieldInfo = resolveField(this.modelMeta, model, field); if (!fieldInfo || !fieldInfo.isDataModel) { continue; } if (fieldInfo.isArray) { - await this.injectGuardForToManyField(db, fieldInfo, subPayload, operation); + this.injectGuardForToManyField(db, fieldInfo, subPayload, operation); } else { - await this.injectGuardForToOneField(db, fieldInfo, subPayload, operation); + this.injectGuardForToOneField(db, fieldInfo, subPayload, operation); } } } - private async injectGuardForToManyField( + private injectGuardForToManyField( db: Record, fieldInfo: FieldInfo, payload: { some?: any; every?: any; none?: any }, @@ -316,12 +316,12 @@ export class PolicyUtil { ) { const guard = this.getAuthGuard(db, fieldInfo.type, operation); if (payload.some) { - await this.injectGuardForRelationFields(db, fieldInfo.type, payload.some, operation); + this.injectGuardForRelationFields(db, fieldInfo.type, payload.some, operation); // turn "some" into: { some: { AND: [guard, payload.some] } } payload.some = this.and(payload.some, guard); } if (payload.none) { - await this.injectGuardForRelationFields(db, fieldInfo.type, payload.none, operation); + this.injectGuardForRelationFields(db, fieldInfo.type, payload.none, operation); // turn none into: { none: { AND: [guard, payload.none] } } payload.none = this.and(payload.none, guard); } @@ -331,7 +331,7 @@ export class PolicyUtil { // ignore empty every clause Object.keys(payload.every).length > 0 ) { - await this.injectGuardForRelationFields(db, fieldInfo.type, payload.every, operation); + this.injectGuardForRelationFields(db, fieldInfo.type, payload.every, operation); // turn "every" into: { none: { AND: [guard, { NOT: payload.every }] } } if (!payload.none) { @@ -342,7 +342,7 @@ export class PolicyUtil { } } - private async injectGuardForToOneField( + private injectGuardForToOneField( db: Record, fieldInfo: FieldInfo, payload: { is?: any; isNot?: any } & Record, @@ -351,18 +351,18 @@ export class PolicyUtil { const guard = this.getAuthGuard(db, fieldInfo.type, operation); if (payload.is || payload.isNot) { if (payload.is) { - await this.injectGuardForRelationFields(db, fieldInfo.type, payload.is, operation); + this.injectGuardForRelationFields(db, fieldInfo.type, payload.is, operation); // turn "is" into: { is: { AND: [ originalIs, guard ] } payload.is = this.and(payload.is, guard); } if (payload.isNot) { - await this.injectGuardForRelationFields(db, fieldInfo.type, payload.isNot, operation); + this.injectGuardForRelationFields(db, fieldInfo.type, payload.isNot, operation); // turn "isNot" into: { isNot: { AND: [ originalIsNot, { NOT: guard } ] } } payload.isNot = this.and(payload.isNot, this.not(guard)); delete payload.isNot; } } else { - await this.injectGuardForRelationFields(db, fieldInfo.type, payload, operation); + this.injectGuardForRelationFields(db, fieldInfo.type, payload, operation); // turn direct conditions into: { is: { AND: [ originalConditions, guard ] } } const combined = this.and(deepcopy(payload), guard); Object.keys(payload).forEach((key) => delete payload[key]); @@ -373,9 +373,9 @@ export class PolicyUtil { /** * Injects auth guard for read operations. */ - async injectForRead(db: Record, model: string, args: any) { + injectForRead(db: Record, model: string, args: any) { const injected: any = {}; - if (!(await this.injectAuthGuard(db, injected, model, 'read'))) { + if (!this.injectAuthGuard(db, injected, model, 'read')) { return false; } @@ -383,7 +383,7 @@ export class PolicyUtil { // inject into relation fields: // to-many: some/none/every // to-one: direct-conditions/is/isNot - await this.injectGuardForRelationFields(db, model, args.where, 'read'); + this.injectGuardForRelationFields(db, model, args.where, 'read'); } if (injected.where && Object.keys(injected.where).length > 0 && !this.isTrue(injected.where)) { @@ -395,7 +395,7 @@ export class PolicyUtil { } // recursively inject read guard conditions into nested select, include, and _count - const hoistedConditions = await this.injectNestedReadConditions(db, model, args); + const hoistedConditions = this.injectNestedReadConditions(db, model, args); // the injection process may generate conditions that need to be hoisted to the toplevel, // if so, merge it with the existing where @@ -441,7 +441,7 @@ export class PolicyUtil { /** * Builds a reversed query for the given nested path. */ - async buildReversedQuery(context: NestedWriteVisitorContext) { + buildReversedQuery(context: NestedWriteVisitorContext) { let result, currQuery: any; let currField: FieldInfo | undefined; @@ -489,11 +489,7 @@ export class PolicyUtil { return result; } - private async injectNestedReadConditions( - db: Record, - model: string, - args: any - ): Promise { + private injectNestedReadConditions(db: Record, model: string, args: any): any[] { const injectTarget = args.select ?? args.include; if (!injectTarget) { return []; @@ -526,7 +522,7 @@ export class PolicyUtil { continue; } // inject into the "where" clause inside select - await this.injectAuthGuard(db, injectTarget._count.select[field], fieldInfo.type, 'read'); + this.injectAuthGuard(db, injectTarget._count.select[field], fieldInfo.type, 'read'); } } @@ -552,10 +548,10 @@ export class PolicyUtil { injectTarget[field] = {}; } // inject extra condition for to-many or nullable to-one relation - await this.injectAuthGuard(db, injectTarget[field], fieldInfo.type, 'read'); + this.injectAuthGuard(db, injectTarget[field], fieldInfo.type, 'read'); // recurse - const subHoisted = await this.injectNestedReadConditions(db, fieldInfo.type, injectTarget[field]); + const subHoisted = this.injectNestedReadConditions(db, fieldInfo.type, injectTarget[field]); if (subHoisted.length > 0) { // we can convert it to a where at this level injectTarget[field].where = this.and(injectTarget[field].where, ...subHoisted); @@ -564,7 +560,7 @@ export class PolicyUtil { // hoist non-nullable to-one filter to the parent level hoisted = this.getAuthGuard(db, fieldInfo.type, 'read'); // recurse - const subHoisted = await this.injectNestedReadConditions(db, fieldInfo.type, injectTarget[field]); + const subHoisted = this.injectNestedReadConditions(db, fieldInfo.type, injectTarget[field]); if (subHoisted.length > 0) { hoisted = this.and(hoisted, ...subHoisted); } @@ -732,7 +728,7 @@ export class PolicyUtil { CrudFailureReason.RESULT_NOT_READABLE ); - const injectResult = await this.injectForRead(db, model, readArgs); + const injectResult = this.injectForRead(db, model, readArgs); if (!injectResult) { return { error, result: undefined }; } @@ -1011,6 +1007,14 @@ export class PolicyUtil { } } + /** + * Gets information for all fields of a model. + */ + getModelFields(model: string) { + model = lowerCaseFirst(model); + return this.modelMeta.fields[model]; + } + /** * Gets information for a specific model field. */ diff --git a/packages/runtime/src/enhancements/policy/promise.ts b/packages/runtime/src/enhancements/policy/promise.ts new file mode 100644 index 000000000..b6d7baff9 --- /dev/null +++ b/packages/runtime/src/enhancements/policy/promise.ts @@ -0,0 +1,38 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ + +/** + * Creates a promise that only executes when it's awaited or .then() is called. + * @see https://github.com/prisma/prisma/blob/main/packages/client/src/runtime/core/request/createPrismaPromise.ts + */ +export function createDeferredPromise(callback: () => Promise): Promise { + let promise: Promise | undefined; + const cb = () => { + try { + return (promise ??= valueToPromise(callback())); + } catch (err) { + // deal with synchronous errors + return Promise.reject(err); + } + }; + + return { + then(onFulfilled, onRejected) { + return cb().then(onFulfilled, onRejected); + }, + catch(onRejected) { + return cb().catch(onRejected); + }, + finally(onFinally) { + return cb().finally(onFinally); + }, + [Symbol.toStringTag]: 'ZenStackPromise', + }; +} + +function valueToPromise(thing: any): Promise { + if (typeof thing === 'object' && typeof thing?.then === 'function') { + return thing; + } else { + return Promise.resolve(thing); + } +} diff --git a/packages/runtime/src/enhancements/proxy.ts b/packages/runtime/src/enhancements/proxy.ts index 717f63d2e..37593a6b6 100644 --- a/packages/runtime/src/enhancements/proxy.ts +++ b/packages/runtime/src/enhancements/proxy.ts @@ -1,7 +1,8 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { PRISMA_TX_FLAG, PRISMA_PROXY_ENHANCER } from '../constants'; +import { PRISMA_PROXY_ENHANCER, PRISMA_TX_FLAG } from '../constants'; import { DbClientContract } from '../types'; +import { createDeferredPromise } from './policy/promise'; import { ModelMeta } from './types'; /** @@ -174,11 +175,7 @@ export function makeProxy( modelMeta: ModelMeta, makeHandler: (prisma: object, model: string) => T, name = 'unnamed_enhancer' - // inTransaction = false ) { - // // put a transaction marker on the proxy target - // prisma[PRISIMA_TX_FLAG] = inTransaction; - const models = Object.keys(modelMeta.fields).map((k) => k.toLowerCase()); const proxy = new Proxy(prisma, { get: (target: any, prop: string | symbol, receiver: any) => { @@ -248,20 +245,39 @@ function createHandlerProxy(handler: T): T { // eslint-disable-next-line @typescript-eslint/ban-types const origMethod = prop as Function; - return async function (...args: any[]) { - // proxying async functions results in messed-up error stack trace, + return function (...args: any[]) { + // using proxy with async functions results in messed-up error stack trace, // create an error to capture the current stack const capture = new Error(ERROR_MARKER); - try { - return await origMethod.apply(handler, args); - } catch (err) { - if (capture.stack && err instanceof Error) { - // save the original stack and replace it with a clean one - (err as any).internalStack = err.stack; - err.stack = cleanCallStack(capture.stack, propKey.toString(), err.message); + + // the original proxy returned by the PrismaClient proxy + const promise: Promise = origMethod.apply(handler, args); + + // modify the error stack + const resultPromise = createDeferredPromise(() => { + return new Promise((resolve, reject) => { + promise.then( + (value) => resolve(value), + (err) => { + if (capture.stack && err instanceof Error) { + // save the original stack and replace it with a clean one + (err as any).internalStack = err.stack; + err.stack = cleanCallStack(capture.stack, propKey.toString(), err.message); + } + reject(err); + } + ); + }); + }); + + // carry over extra fields from the original promise + for (const [k, v] of Object.entries(promise)) { + if (!(k in resultPromise)) { + (resultPromise as any)[k] = v; } - throw err; } + + return resultPromise; }; }, }); @@ -287,7 +303,7 @@ function cleanCallStack(stack: string, method: string, message: string) { } // skip leading zenstack and anonymous lines - if (line.includes('@zenstackhq/runtime') || line.includes('')) { + if (line.includes('@zenstackhq/runtime') || line.includes('Proxy.')) { continue; } diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index 76366d87e..b09b5052c 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -1,25 +1,27 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ +export type PrismaPromise = Promise & Record PrismaPromise>; + /** * Weakly-typed database access methods */ export interface DbOperations { - findMany(args?: unknown): Promise; - findFirst(args: unknown): Promise; - findFirstOrThrow(args: unknown): Promise; - findUnique(args: unknown): Promise; - findUniqueOrThrow(args: unknown): Promise; - create(args: unknown): Promise; + findMany(args?: unknown): Promise; + findFirst(args?: unknown): PrismaPromise; + findFirstOrThrow(args?: unknown): PrismaPromise; + findUnique(args: unknown): PrismaPromise; + findUniqueOrThrow(args: unknown): PrismaPromise; + create(args: unknown): Promise; createMany(args: unknown, skipDuplicates?: boolean): Promise<{ count: number }>; - update(args: unknown): Promise; + update(args: unknown): Promise; updateMany(args: unknown): Promise<{ count: number }>; - upsert(args: unknown): Promise; - delete(args: unknown): Promise; + upsert(args: unknown): Promise; + delete(args: unknown): Promise; deleteMany(args?: unknown): Promise<{ count: number }>; - aggregate(args: unknown): Promise; - groupBy(args: unknown): Promise; - count(args?: unknown): Promise; - subscribe(args?: unknown): Promise; + aggregate(args: unknown): Promise; + groupBy(args: unknown): Promise; + count(args?: unknown): Promise; + subscribe(args?: unknown): Promise; fields: Record; } diff --git a/packages/schema/src/utils/version-utils.ts b/packages/schema/src/utils/version-utils.ts index 5ebc41bee..0e2de705d 100644 --- a/packages/schema/src/utils/version-utils.ts +++ b/packages/schema/src/utils/version-utils.ts @@ -3,7 +3,11 @@ export function getVersion() { try { return require('../package.json').version; } catch { - // dev environment - return require('../../package.json').version; + try { + // dev environment + return require('../../package.json').version; + } catch { + return undefined; + } } } diff --git a/packages/server/src/sveltekit/handler.ts b/packages/server/src/sveltekit/handler.ts index 2dbdf7e1d..f45eaf9db 100644 --- a/packages/server/src/sveltekit/handler.ts +++ b/packages/server/src/sveltekit/handler.ts @@ -36,7 +36,7 @@ export default function createHandler(options: HandlerOptions): Handle { } } - const requestHanler = options.handler ?? RPCApiHandler(); + const requestHandler = options.handler ?? RPCApiHandler(); if (options.useSuperJson !== undefined) { console.warn( 'The option "useSuperJson" is deprecated. The server APIs automatically use superjson for serialization.' @@ -67,7 +67,7 @@ export default function createHandler(options: HandlerOptions): Handle { const path = event.url.pathname.substring(options.prefix.length); try { - const r = await requestHanler({ + const r = await requestHandler({ method: event.request.method, path, query, diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index ae81d453e..42d30df04 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -24,18 +24,14 @@ import prismaPlugin from 'zenstack/plugins/prisma'; */ export const FILE_SPLITTER = '#FILE_SPLITTER#'; -export type WeakDbOperations = { - [key in keyof DbOperations]: (...args: any[]) => Promise; -}; - -export type WeakDbClientContract = Record & { +export type FullDbClientContract = Record & { $on(eventType: any, callback: (event: any) => void): void; $use(cb: any): void; $disconnect: () => Promise; - $transaction: (input: ((tx: WeakDbClientContract) => Promise) | any[], options?: any) => Promise; + $transaction: (input: ((tx: FullDbClientContract) => Promise) | any[], options?: any) => Promise; $queryRaw: (query: TemplateStringsArray, ...args: any[]) => Promise; $executeRaw: (query: TemplateStringsArray, ...args: any[]) => Promise; - $extends: (args: any) => WeakDbClientContract; + $extends: (args: any) => FullDbClientContract; }; export function run(cmd: string, env?: Record, cwd?: string) { @@ -245,15 +241,15 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) { projectDir: projectRoot, prisma, withPolicy: (user?: AuthUser) => - withPolicy( + withPolicy( prisma, { user }, { policy, modelMeta, zodSchemas, logPrismaQuery: opt.logPrismaQuery } ), - withOmit: () => withOmit(prisma, { modelMeta }), - withPassword: () => withPassword(prisma, { modelMeta }), + withOmit: () => withOmit(prisma, { modelMeta }), + withPassword: () => withPassword(prisma, { modelMeta }), enhance: (user?: AuthUser) => - enhance( + enhance( prisma, { user }, { policy, modelMeta, zodSchemas, logPrismaQuery: opt.logPrismaQuery } diff --git a/tests/integration/tests/e2e/prisma-methods.test.ts b/tests/integration/tests/e2e/prisma-methods.test.ts index 1a2efde72..2053f0a73 100644 --- a/tests/integration/tests/e2e/prisma-methods.test.ts +++ b/tests/integration/tests/e2e/prisma-methods.test.ts @@ -1,9 +1,9 @@ import { AuthUser } from '@zenstackhq/runtime'; -import { WeakDbClientContract, loadSchema, run } from '@zenstackhq/testtools'; +import { FullDbClientContract, loadSchema, run } from '@zenstackhq/testtools'; describe('Prisma Methods Tests', () => { - let getDb: (user?: AuthUser) => WeakDbClientContract; - let prisma: WeakDbClientContract; + let getDb: (user?: AuthUser) => FullDbClientContract; + let prisma: FullDbClientContract; beforeAll(async () => { const { enhance, prisma: _prisma } = await loadSchema( diff --git a/tests/integration/tests/e2e/todo-presets.test.ts b/tests/integration/tests/e2e/todo-presets.test.ts index b454d8dda..dbd7f4003 100644 --- a/tests/integration/tests/e2e/todo-presets.test.ts +++ b/tests/integration/tests/e2e/todo-presets.test.ts @@ -1,11 +1,11 @@ import { AuthUser } from '@zenstackhq/runtime'; -import { loadSchemaFromFile, run, type WeakDbClientContract } from '@zenstackhq/testtools'; +import { loadSchemaFromFile, run, type FullDbClientContract } from '@zenstackhq/testtools'; import { compareSync } from 'bcryptjs'; import path from 'path'; describe('Todo Presets Tests', () => { - let getDb: (user?: AuthUser) => WeakDbClientContract; - let prisma: WeakDbClientContract; + let getDb: (user?: AuthUser) => FullDbClientContract; + let prisma: FullDbClientContract; beforeAll(async () => { const { enhance, prisma: _prisma } = await loadSchemaFromFile(path.join(__dirname, '../schema/todo.zmodel'), { diff --git a/tests/integration/tests/e2e/type-coverage.test.ts b/tests/integration/tests/e2e/type-coverage.test.ts index 275f0d70c..c8c88211c 100644 --- a/tests/integration/tests/e2e/type-coverage.test.ts +++ b/tests/integration/tests/e2e/type-coverage.test.ts @@ -1,11 +1,11 @@ import { AuthUser } from '@zenstackhq/runtime'; -import { loadSchema, run, type WeakDbClientContract } from '@zenstackhq/testtools'; +import { loadSchema, run, type FullDbClientContract } from '@zenstackhq/testtools'; import Decimal from 'decimal.js'; import superjson from 'superjson'; describe('Type Coverage Tests', () => { - let getDb: (user?: AuthUser) => WeakDbClientContract; - let prisma: WeakDbClientContract; + let getDb: (user?: AuthUser) => FullDbClientContract; + let prisma: FullDbClientContract; beforeAll(async () => { const { enhance, prisma: _prisma } = await loadSchema( diff --git a/tests/integration/tests/enhancements/with-policy/auth.test.ts b/tests/integration/tests/enhancements/with-policy/auth.test.ts index a2dfc1b86..0eed19f9d 100644 --- a/tests/integration/tests/enhancements/with-policy/auth.test.ts +++ b/tests/integration/tests/enhancements/with-policy/auth.test.ts @@ -12,7 +12,7 @@ describe('With Policy: auth() test', () => { process.chdir(origDir); }); - it('undefined user with string id', async () => { + it('undefined user with string id simple', async () => { const { withPolicy } = await loadSchema( ` model User { diff --git a/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts b/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts index f2d2aa2ce..9608f9c62 100644 --- a/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts +++ b/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts @@ -1,4 +1,4 @@ -import { loadSchema, type WeakDbClientContract } from '@zenstackhq/testtools'; +import { loadSchema, type FullDbClientContract } from '@zenstackhq/testtools'; import path from 'path'; describe('With Policy:deep nested', () => { @@ -60,8 +60,8 @@ describe('With Policy:deep nested', () => { } `; - let db: WeakDbClientContract; - let prisma: WeakDbClientContract; + let db: FullDbClientContract; + let prisma: FullDbClientContract; beforeAll(async () => { origDir = path.resolve('.'); diff --git a/tests/integration/tests/enhancements/with-policy/field-validation.test.ts b/tests/integration/tests/enhancements/with-policy/field-validation.test.ts index 4e3d8a4e8..f0f57ab25 100644 --- a/tests/integration/tests/enhancements/with-policy/field-validation.test.ts +++ b/tests/integration/tests/enhancements/with-policy/field-validation.test.ts @@ -1,7 +1,7 @@ -import { WeakDbClientContract, loadSchema, run } from '@zenstackhq/testtools'; +import { FullDbClientContract, loadSchema, run } from '@zenstackhq/testtools'; describe('With Policy: field validation', () => { - let db: WeakDbClientContract; + let db: FullDbClientContract; beforeAll(async () => { const { withPolicy, prisma: _prisma } = await loadSchema( diff --git a/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts b/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts new file mode 100644 index 000000000..264c5da28 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/fluent-api.test.ts @@ -0,0 +1,104 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import path from 'path'; + +describe('With Policy: fluent API', () => { + let origDir: string; + + beforeAll(async () => { + origDir = path.resolve('.'); + }); + + afterEach(async () => { + process.chdir(origDir); + }); + + it('fluent api', async () => { + const { withPolicy, prisma } = await loadSchema( + ` +model User { + id Int @id + email String @unique + posts Post[] + + @@allow('all', true) +} + +model Post { + id Int @id + title String + author User? @relation(fields: [authorId], references: [id]) + authorId Int? + published Boolean @default(false) + secret String @default("secret") @allow('read', published == false) + + @@allow('all', author == auth()) +}` + ); + + await prisma.user.create({ + data: { + id: 1, + email: 'a@test.com', + posts: { + create: [ + { id: 1, title: 'post1', published: true }, + { id: 2, title: 'post2', published: false }, + ], + }, + }, + }); + + await prisma.user.create({ + data: { + id: 2, + email: 'b@test.com', + posts: { + create: [{ id: 3, title: 'post3' }], + }, + }, + }); + + const db = withPolicy({ id: 1 }); + + // check policies + await expect(db.user.findUnique({ where: { id: 1 } }).posts()).resolves.toHaveLength(2); + await expect( + db.user.findUnique({ where: { id: 1 } }).posts({ where: { published: true } }) + ).resolves.toHaveLength(1); + await expect(db.user.findUnique({ where: { id: 1 } }).posts({ take: 1 })).resolves.toHaveLength(1); + + // field-level policies + let p = (await db.user.findUnique({ where: { id: 1 } }).posts({ where: { published: true } }))[0]; + expect(p.secret).toBeUndefined(); + p = (await db.user.findUnique({ where: { id: 1 } }).posts({ where: { published: false } }))[0]; + expect(p.secret).toBeTruthy(); + + // to-one + await expect(db.post.findFirst({ where: { id: 1 } }).author()).resolves.toEqual( + expect.objectContaining({ id: 1, email: 'a@test.com' }) + ); + + // not-found + await expect(db.user.findUniqueOrThrow({ where: { id: 5 } }).posts()).toBeNotFound(); + await expect(db.user.findFirstOrThrow({ where: { id: 5 } }).posts()).toBeNotFound(); + await expect(db.post.findUniqueOrThrow({ where: { id: 5 } }).author()).toBeNotFound(); + await expect(db.post.findFirstOrThrow({ where: { id: 5 } }).author()).toBeNotFound(); + + // chaining + await expect( + db.post + .findFirst({ where: { id: 1 } }) + .author() + .posts() + ).resolves.toHaveLength(2); + + // chaining broken + expect((db.post.findMany() as any).author).toBeUndefined(); + expect( + db.post + .findFirst({ where: { id: 1 } }) + .author() + .posts().author + ).toBeUndefined(); + }); +}); diff --git a/tests/integration/tests/enhancements/with-policy/petstore-sample.test.ts b/tests/integration/tests/enhancements/with-policy/petstore-sample.test.ts index 88b0b1f7d..9c251faf5 100644 --- a/tests/integration/tests/enhancements/with-policy/petstore-sample.test.ts +++ b/tests/integration/tests/enhancements/with-policy/petstore-sample.test.ts @@ -1,10 +1,10 @@ import { AuthUser } from '@zenstackhq/runtime'; -import { loadSchemaFromFile, run, type WeakDbClientContract } from '@zenstackhq/testtools'; +import { loadSchemaFromFile, run, type FullDbClientContract } from '@zenstackhq/testtools'; import path from 'path'; describe('Pet Store Policy Tests', () => { - let getDb: (user?: AuthUser) => WeakDbClientContract; - let prisma: WeakDbClientContract; + let getDb: (user?: AuthUser) => FullDbClientContract; + let prisma: FullDbClientContract; beforeAll(async () => { const { withPolicy, prisma: _prisma } = await loadSchemaFromFile( diff --git a/tests/integration/tests/enhancements/with-policy/postgres.test.ts b/tests/integration/tests/enhancements/with-policy/postgres.test.ts index a6c389f92..caed6a5ce 100644 --- a/tests/integration/tests/enhancements/with-policy/postgres.test.ts +++ b/tests/integration/tests/enhancements/with-policy/postgres.test.ts @@ -1,5 +1,5 @@ import { AuthUser } from '@zenstackhq/runtime'; -import { createPostgresDb, dropPostgresDb, loadSchemaFromFile, type WeakDbClientContract } from '@zenstackhq/testtools'; +import { createPostgresDb, dropPostgresDb, loadSchemaFromFile, type FullDbClientContract } from '@zenstackhq/testtools'; import path from 'path'; const DB_NAME = 'todo-pg'; @@ -7,8 +7,8 @@ const DB_NAME = 'todo-pg'; describe('With Policy: with postgres', () => { let origDir: string; let dbUrl: string; - let getDb: (user?: AuthUser) => WeakDbClientContract; - let prisma: WeakDbClientContract; + let getDb: (user?: AuthUser) => FullDbClientContract; + let prisma: FullDbClientContract; beforeAll(async () => { origDir = path.resolve('.'); @@ -483,7 +483,7 @@ const space2 = { slug: 'space2', }; -async function createSpaceAndUsers(db: WeakDbClientContract) { +async function createSpaceAndUsers(db: FullDbClientContract) { // create users await db.user.create({ data: user1 }); await db.user.create({ data: user2 }); diff --git a/tests/integration/tests/enhancements/with-policy/refactor.test.ts b/tests/integration/tests/enhancements/with-policy/refactor.test.ts index 4aca6ba88..adc2599ec 100644 --- a/tests/integration/tests/enhancements/with-policy/refactor.test.ts +++ b/tests/integration/tests/enhancements/with-policy/refactor.test.ts @@ -1,5 +1,5 @@ import { AuthUser, PrismaErrorCode } from '@zenstackhq/runtime'; -import { createPostgresDb, dropPostgresDb, loadSchemaFromFile, type WeakDbClientContract } from '@zenstackhq/testtools'; +import { createPostgresDb, dropPostgresDb, loadSchemaFromFile, type FullDbClientContract } from '@zenstackhq/testtools'; import path from 'path'; const DB_NAME = 'refactor'; @@ -7,12 +7,12 @@ const DB_NAME = 'refactor'; describe('With Policy: refactor tests', () => { let origDir: string; let dbUrl: string; - let getDb: (user?: AuthUser) => WeakDbClientContract; - let prisma: WeakDbClientContract; - let anonDb: WeakDbClientContract; - let adminDb: WeakDbClientContract; - let user1Db: WeakDbClientContract; - let user2Db: WeakDbClientContract; + let getDb: (user?: AuthUser) => FullDbClientContract; + let prisma: FullDbClientContract; + let anonDb: FullDbClientContract; + let adminDb: FullDbClientContract; + let user1Db: FullDbClientContract; + let user2Db: FullDbClientContract; beforeAll(async () => { origDir = path.resolve('.'); diff --git a/tests/integration/tests/enhancements/with-policy/todo-sample.test.ts b/tests/integration/tests/enhancements/with-policy/todo-sample.test.ts index 0f3305e0e..2b7dd416b 100644 --- a/tests/integration/tests/enhancements/with-policy/todo-sample.test.ts +++ b/tests/integration/tests/enhancements/with-policy/todo-sample.test.ts @@ -1,10 +1,10 @@ import { AuthUser } from '@zenstackhq/runtime'; -import { loadSchemaFromFile, run, type WeakDbClientContract } from '@zenstackhq/testtools'; +import { loadSchemaFromFile, run, type FullDbClientContract } from '@zenstackhq/testtools'; import path from 'path'; describe('Todo Policy Tests', () => { - let getDb: (user?: AuthUser) => WeakDbClientContract; - let prisma: WeakDbClientContract; + let getDb: (user?: AuthUser) => FullDbClientContract; + let prisma: FullDbClientContract; beforeAll(async () => { const { withPolicy, prisma: _prisma } = await loadSchemaFromFile( @@ -468,7 +468,7 @@ const space2 = { slug: 'space2', }; -async function createSpaceAndUsers(db: WeakDbClientContract) { +async function createSpaceAndUsers(db: FullDbClientContract) { // create users await db.user.create({ data: user1 }); await db.user.create({ data: user2 }); diff --git a/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts b/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts index 15626e1c2..99179e015 100644 --- a/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts +++ b/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts @@ -1,7 +1,7 @@ import { loadSchema } from '@zenstackhq/testtools'; import path from 'path'; -describe('With Policy:toplevel operations', () => { +describe('With Policy: toplevel operations', () => { let origDir: string; beforeAll(async () => {