diff --git a/.eslintrc.json b/.eslintrc.json index e04b04831..707715244 100644 --- a/.eslintrc.json +++ b/.eslintrc.json @@ -13,7 +13,7 @@ "plugin:jest/recommended" ], "rules": { - "jest/expect-expect": "off", - "@typescript-eslint/no-unused-vars": ["error", { "varsIgnorePattern": "^_", "argsIgnorePattern": "^_" }] + "@typescript-eslint/no-unused-vars": ["error", { "varsIgnorePattern": "^_", "argsIgnorePattern": "^_" }], + "jest/expect-expect": "off" } } diff --git a/packages/language/src/ast.ts b/packages/language/src/ast.ts index 86dd55bed..3da706a75 100644 --- a/packages/language/src/ast.ts +++ b/packages/language/src/ast.ts @@ -52,6 +52,17 @@ declare module './generated/ast' { interface DataModelAttribute { $inheritedFrom?: DataModel; } + + export interface DataModel { + /** + * Indicates whether the model is already merged with the base types + */ + $baseMerged?: boolean; + } +} + +export interface InheritableNode extends AstNode { + $inheritedFrom?: DataModel; } export interface InheritableNode extends AstNode { diff --git a/packages/plugins/swr/tests/test-model-meta.ts b/packages/plugins/swr/tests/test-model-meta.ts index 41731ad18..71a657bad 100644 --- a/packages/plugins/swr/tests/test-model-meta.ts +++ b/packages/plugins/swr/tests/test-model-meta.ts @@ -11,39 +11,46 @@ const fieldDefaults = { }; export const modelMeta: ModelMeta = { - fields: { + models: { user: { - id: { - ...fieldDefaults, - type: 'String', - isId: true, - name: 'id', - isOptional: false, - }, - name: { ...fieldDefaults, type: 'String', name: 'name' }, - email: { ...fieldDefaults, type: 'String', name: 'name', isOptional: false }, - posts: { - ...fieldDefaults, - type: 'Post', - isDataModel: true, - isArray: true, - name: 'posts', + name: 'user', + fields: { + id: { + ...fieldDefaults, + type: 'String', + isId: true, + name: 'id', + isOptional: false, + }, + name: { ...fieldDefaults, type: 'String', name: 'name' }, + email: { ...fieldDefaults, type: 'String', name: 'name', isOptional: false }, + posts: { + ...fieldDefaults, + type: 'Post', + isDataModel: true, + isArray: true, + name: 'posts', + }, }, + uniqueConstraints: {}, }, post: { - id: { - ...fieldDefaults, - type: 'String', - isId: true, - name: 'id', - isOptional: false, + name: 'post', + fields: { + id: { + ...fieldDefaults, + type: 'String', + isId: true, + name: 'id', + isOptional: false, + }, + title: { ...fieldDefaults, type: 'String', name: 'title' }, + owner: { ...fieldDefaults, type: 'User', name: 'owner', isDataModel: true, isRelationOwner: true }, + ownerId: { ...fieldDefaults, type: 'User', name: 'owner', isForeignKey: true }, }, - title: { ...fieldDefaults, type: 'String', name: 'title' }, - owner: { ...fieldDefaults, type: 'User', name: 'owner', isDataModel: true, isRelationOwner: true }, - ownerId: { ...fieldDefaults, type: 'User', name: 'owner', isForeignKey: true }, + uniqueConstraints: {}, }, }, - uniqueConstraints: {}, deleteCascade: { user: ['Post'], }, diff --git a/packages/plugins/tanstack-query/tests/test-model-meta.ts b/packages/plugins/tanstack-query/tests/test-model-meta.ts index 41731ad18..71a657bad 100644 --- a/packages/plugins/tanstack-query/tests/test-model-meta.ts +++ b/packages/plugins/tanstack-query/tests/test-model-meta.ts @@ -11,39 +11,46 @@ const fieldDefaults = { }; export const modelMeta: ModelMeta = { - fields: { + models: { user: { - id: { - ...fieldDefaults, - type: 'String', - isId: true, - name: 'id', - isOptional: false, - }, - name: { ...fieldDefaults, type: 'String', name: 'name' }, - email: { ...fieldDefaults, type: 'String', name: 'name', isOptional: false }, - posts: { - ...fieldDefaults, - type: 'Post', - isDataModel: true, - isArray: true, - name: 'posts', + name: 'user', + fields: { + id: { + ...fieldDefaults, + type: 'String', + isId: true, + name: 'id', + isOptional: false, + }, + name: { ...fieldDefaults, type: 'String', name: 'name' }, + email: { ...fieldDefaults, type: 'String', name: 'name', isOptional: false }, + posts: { + ...fieldDefaults, + type: 'Post', + isDataModel: true, + isArray: true, + name: 'posts', + }, }, + uniqueConstraints: {}, }, post: { - id: { - ...fieldDefaults, - type: 'String', - isId: true, - name: 'id', - isOptional: false, + name: 'post', + fields: { + id: { + ...fieldDefaults, + type: 'String', + isId: true, + name: 'id', + isOptional: false, + }, + title: { ...fieldDefaults, type: 'String', name: 'title' }, + owner: { ...fieldDefaults, type: 'User', name: 'owner', isDataModel: true, isRelationOwner: true }, + ownerId: { ...fieldDefaults, type: 'User', name: 'owner', isForeignKey: true }, }, - title: { ...fieldDefaults, type: 'String', name: 'title' }, - owner: { ...fieldDefaults, type: 'User', name: 'owner', isDataModel: true, isRelationOwner: true }, - ownerId: { ...fieldDefaults, type: 'User', name: 'owner', isForeignKey: true }, + uniqueConstraints: {}, }, }, - uniqueConstraints: {}, deleteCascade: { user: ['Post'], }, diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 5ae027701..1d5c8fd37 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -56,8 +56,10 @@ "bcryptjs": "^2.4.3", "buffer": "^6.0.3", "change-case": "^4.1.2", + "colors": "1.4.0", "decimal.js": "^10.4.2", "deepcopy": "^2.1.0", + "deepmerge": "^4.3.1", "lower-case-first": "^2.0.2", "pluralize": "^8.0.0", "semver": "^7.5.2", diff --git a/packages/runtime/src/constants.ts b/packages/runtime/src/constants.ts index c381a5a88..a85392887 100644 --- a/packages/runtime/src/constants.ts +++ b/packages/runtime/src/constants.ts @@ -97,3 +97,8 @@ export const FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX = 'updateFieldGuardOverrid * Flag that indicates if the model has field-level access control */ export const HAS_FIELD_LEVEL_POLICY_FLAG = 'hasFieldLevelPolicy'; + +/** + * Prefix for auxiliary relation field generated for delegated models + */ +export const DELEGATE_AUX_RELATION_PREFIX = 'delegate_aux'; diff --git a/packages/runtime/src/cross/model-meta.ts b/packages/runtime/src/cross/model-meta.ts index 401caeaf2..9f767af0e 100644 --- a/packages/runtime/src/cross/model-meta.ts +++ b/packages/runtime/src/cross/model-meta.ts @@ -4,7 +4,14 @@ import { lowerCaseFirst } from 'lower-case-first'; * Runtime information of a data model or field attribute */ export type RuntimeAttribute = { + /** + * Attribute name + */ name: string; + + /** + * Attribute arguments + */ args: Array<{ name?: string; value: unknown }>; }; @@ -72,6 +79,11 @@ export type FieldInfo = { */ foreignKeyMapping?: Record; + /** + * Model from which the field is inherited + */ + inheritedFrom?: string; + /** * A function that provides a default value for the field */ @@ -90,23 +102,53 @@ export type FieldInfo = { export type UniqueConstraint = { name: string; fields: string[] }; /** - * ZModel data model metadata + * Metadata for a data model */ -export type ModelMeta = { +export type ModelInfo = { + /** + * Model name + */ + name: string; + + /** + * Base types + */ + baseTypes?: string[]; + + /** + * Fields + */ + fields: Record; + + /** + * Unique constraints + */ + uniqueConstraints?: Record; + + /** + * Attributes on the model + */ + attributes?: RuntimeAttribute[]; + /** - * Model fields + * Discriminator field name */ - fields: Record>; + discriminator?: string; +}; +/** + * ZModel data model metadata + */ +export type ModelMeta = { /** - * Model unique constraints + * Data models */ - uniqueConstraints: Record>; + models: Record; /** - * Information for cascading delete + * Mapping from model name to models that will be deleted because of it due to cascade delete */ - deleteCascade: Record; + deleteCascade?: Record; /** * Name of model that backs the `auth()` function @@ -117,8 +159,8 @@ export type ModelMeta = { /** * Resolves a model field to its metadata. Returns undefined if not found. */ -export function resolveField(modelMeta: ModelMeta, model: string, field: string) { - return modelMeta.fields[lowerCaseFirst(model)]?.[field]; +export function resolveField(modelMeta: ModelMeta, model: string, field: string): FieldInfo | undefined { + return modelMeta.models[lowerCaseFirst(model)]?.fields?.[field]; } /** @@ -136,5 +178,12 @@ export function requireField(modelMeta: ModelMeta, model: string, field: string) * Gets all fields of a model. */ export function getFields(modelMeta: ModelMeta, model: string) { - return modelMeta.fields[lowerCaseFirst(model)]; + return modelMeta.models[lowerCaseFirst(model)]?.fields; +} + +/** + * Gets unique constraints of a model. + */ +export function getUniqueConstraints(modelMeta: ModelMeta, model: string) { + return modelMeta.models[lowerCaseFirst(model)]?.uniqueConstraints; } diff --git a/packages/runtime/src/cross/nested-write-visitor.ts b/packages/runtime/src/cross/nested-write-visitor.ts index 477117dbd..db2455d7e 100644 --- a/packages/runtime/src/cross/nested-write-visitor.ts +++ b/packages/runtime/src/cross/nested-write-visitor.ts @@ -219,8 +219,10 @@ export class NestedWriteVisitor { case 'set': if (this.callback.set) { - const newContext = pushNewContext(field, model, {}); - await this.callback.set(model, data, newContext); + for (const item of enumerate(data)) { + const newContext = pushNewContext(field, model, item, true); + await this.callback.set(model, item, newContext); + } } break; diff --git a/packages/runtime/src/cross/query-analyzer.ts b/packages/runtime/src/cross/query-analyzer.ts index 5af410e82..bf501f020 100644 --- a/packages/runtime/src/cross/query-analyzer.ts +++ b/packages/runtime/src/cross/query-analyzer.ts @@ -81,7 +81,7 @@ function collectDeleteCascades(model: string, modelMeta: ModelMeta, result: Set< } visited.add(model); - const cascades = modelMeta.deleteCascade[lowerCaseFirst(model)]; + const cascades = modelMeta.deleteCascade?.[lowerCaseFirst(model)]; if (!cascades) { return; diff --git a/packages/runtime/src/cross/utils.ts b/packages/runtime/src/cross/utils.ts index e4237dbc7..1982513b3 100644 --- a/packages/runtime/src/cross/utils.ts +++ b/packages/runtime/src/cross/utils.ts @@ -1,5 +1,5 @@ import { lowerCaseFirst } from 'lower-case-first'; -import { ModelMeta } from '.'; +import { ModelInfo, ModelMeta } from '.'; /** * Gets field names in a data model entity, filtering out internal fields. @@ -47,7 +47,7 @@ export function zip(x: Enumerable, y: Enumerable): Array<[T1, T2 } export function getIdFields(modelMeta: ModelMeta, model: string, throwIfNotFound = false) { - let fields = modelMeta.fields[lowerCaseFirst(model)]; + let fields = modelMeta.models[lowerCaseFirst(model)]?.fields; if (!fields) { if (throwIfNotFound) { throw new Error(`Unable to load fields for ${model}`); @@ -61,3 +61,19 @@ export function getIdFields(modelMeta: ModelMeta, model: string, throwIfNotFound } return result; } + +export function getModelInfo( + modelMeta: ModelMeta, + model: string, + throwIfNotFound: Throw = false as Throw +): Throw extends true ? ModelInfo : ModelInfo | undefined { + const info = modelMeta.models[lowerCaseFirst(model)]; + if (!info && throwIfNotFound) { + throw new Error(`Unable to load info for ${model}`); + } + return info; +} + +export function isDelegateModel(modelMeta: ModelMeta, model: string) { + return !!getModelInfo(modelMeta, model)?.attributes?.some((attr) => attr.name === '@@delegate'); +} diff --git a/packages/runtime/src/enhancements/create-enhancement.ts b/packages/runtime/src/enhancements/create-enhancement.ts index b137e03f9..dbca40874 100644 --- a/packages/runtime/src/enhancements/create-enhancement.ts +++ b/packages/runtime/src/enhancements/create-enhancement.ts @@ -1,8 +1,11 @@ +import colors from 'colors'; import semver from 'semver'; import { PRISMA_MINIMUM_VERSION } from '../constants'; -import { ModelMeta } from '../cross'; +import { isDelegateModel, type ModelMeta } from '../cross'; import type { AuthUser } from '../types'; import { withDefaultAuth } from './default-auth'; +import { withDelegate } from './delegate'; +import { Logger } from './logger'; import { withOmit } from './omit'; import { withPassword } from './password'; import { withPolicy } from './policy'; @@ -12,12 +15,12 @@ import type { PolicyDef, ZodSchemas } from './types'; /** * Kinds of enhancements to `PrismaClient` */ -export enum EnhancementKind { - Password = 'password', - Omit = 'omit', - Policy = 'policy', - DefaultAuth = 'defaultAuth', -} +export type EnhancementKind = 'password' | 'omit' | 'policy' | 'delegate'; + +/** + * All enhancement kinds + */ +const ALL_ENHANCEMENTS = ['password', 'omit', 'policy', 'delegate']; /** * Transaction isolation levels: https://www.prisma.io/docs/orm/prisma-client/queries/transactions#transaction-isolation-level @@ -121,6 +124,9 @@ export function createEnhancement( ); } + const logger = new Logger(prisma); + logger.info(`Enabled ZenStack enhancements: ${options.kinds?.join(', ')}`); + let result = prisma; if ( @@ -129,38 +135,48 @@ export function createEnhancement( hasOmit === undefined || hasDefaultAuth === undefined ) { - const allFields = Object.values(options.modelMeta.fields).flatMap((modelInfo) => Object.values(modelInfo)); + const allFields = Object.values(options.modelMeta.models).flatMap((modelInfo) => + Object.values(modelInfo.fields) + ); hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password')); hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit')); hasDefaultAuth = allFields.some((field) => field.defaultValueProvider); } - const kinds = options.kinds ?? [ - EnhancementKind.Password, - EnhancementKind.Omit, - EnhancementKind.Policy, - EnhancementKind.DefaultAuth, - ]; + const kinds = options.kinds ?? ALL_ENHANCEMENTS; + + // delegate proxy needs to be wrapped inside policy proxy, since it may translate `deleteMany` + // and `updateMany` to plain `delete` and `update` + if (Object.values(options.modelMeta.models).some((model) => isDelegateModel(options.modelMeta, model.name))) { + if (!kinds.includes('delegate')) { + console.warn( + colors.yellow( + 'Your ZModel contains delegate models but "delegate" enhancement kind is not enabled. This may result in unexpected behavior.' + ) + ); + } else { + result = withDelegate(result, options); + } + } - if (hasPassword && kinds.includes(EnhancementKind.Password)) { + // policy proxy + if (kinds.includes('policy')) { + result = withPolicy(result, options, context); + if (hasDefaultAuth) { + // @default(auth()) proxy + result = withDefaultAuth(result, options, context); + } + } + + if (hasPassword && kinds.includes('password')) { // @password proxy result = withPassword(result, options); } - if (hasOmit && kinds.includes(EnhancementKind.Omit)) { + if (hasOmit && kinds.includes('omit')) { // @omit proxy result = withOmit(result, options); } - if (hasDefaultAuth && kinds.includes(EnhancementKind.DefaultAuth)) { - // @default(auth()) proxy - result = withDefaultAuth(result, options, context); - } - - // policy proxy - if (kinds.includes(EnhancementKind.Policy)) { - result = withPolicy(result, options, context); - } - return result; } diff --git a/packages/runtime/src/enhancements/default-auth.ts b/packages/runtime/src/enhancements/default-auth.ts index 48af0ed73..9e0a64a4f 100644 --- a/packages/runtime/src/enhancements/default-auth.ts +++ b/packages/runtime/src/enhancements/default-auth.ts @@ -26,17 +26,15 @@ export function withDefaultAuth( } class DefaultAuthHandler extends DefaultPrismaProxyHandler { - private readonly db: DbClientContract; private readonly userContext: any; constructor( prisma: DbClientContract, model: string, - private readonly options: EnhancementOptions, + options: EnhancementOptions, private readonly context?: EnhancementContext ) { - super(prisma, model); - this.db = prisma; + super(prisma, model, options); if (!this.context?.user) { throw new Error(`Using \`auth()\` in \`@default\` requires a user context`); diff --git a/packages/runtime/src/enhancements/delegate.ts b/packages/runtime/src/enhancements/delegate.ts new file mode 100644 index 000000000..0a1e39d8c --- /dev/null +++ b/packages/runtime/src/enhancements/delegate.ts @@ -0,0 +1,1133 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +import deepcopy from 'deepcopy'; +import deepmerge from 'deepmerge'; +import { lowerCaseFirst } from 'lower-case-first'; +import { DELEGATE_AUX_RELATION_PREFIX } from '../constants'; +import { + FieldInfo, + ModelInfo, + NestedWriteVisitor, + enumerate, + getIdFields, + getModelInfo, + isDelegateModel, + requireField, + resolveField, +} from '../cross'; +import type { CrudContract, DbClientContract } from '../types'; +import type { EnhancementOptions } from './create-enhancement'; +import { Logger } from './logger'; +import { DefaultPrismaProxyHandler, makeProxy } from './proxy'; +import { QueryUtils } from './query-utils'; +import { formatObject, prismaClientValidationError } from './utils'; + +export function withDelegate(prisma: DbClient, options: EnhancementOptions): DbClient { + return makeProxy( + prisma, + options.modelMeta, + (_prisma, model) => new DelegateProxyHandler(_prisma as DbClientContract, model, options), + 'delegate' + ); +} + +export class DelegateProxyHandler extends DefaultPrismaProxyHandler { + private readonly logger: Logger; + private readonly queryUtils: QueryUtils; + + constructor(prisma: DbClientContract, model: string, options: EnhancementOptions) { + super(prisma, model, options); + this.logger = new Logger(prisma); + this.queryUtils = new QueryUtils(prisma, this.options); + } + + // #region find + + override findFirst(args: any): Promise { + return this.doFind(this.prisma, this.model, 'findFirst', args); + } + + override findFirstOrThrow(args: any): Promise { + return this.doFind(this.prisma, this.model, 'findFirstOrThrow', args); + } + + override findUnique(args: any): Promise { + return this.doFind(this.prisma, this.model, 'findUnique', args); + } + + override findUniqueOrThrow(args: any): Promise { + return this.doFind(this.prisma, this.model, 'findUniqueOrThrow', args); + } + + override async findMany(args: any): Promise { + return this.doFind(this.prisma, this.model, 'findMany', args); + } + + private async doFind( + db: CrudContract, + model: string, + method: 'findFirst' | 'findFirstOrThrow' | 'findUnique' | 'findUniqueOrThrow' | 'findMany', + args: any + ) { + if (!this.involvesDelegateModel(model)) { + return super[method](args); + } + + args = args ? deepcopy(args) : {}; + + this.injectWhereHierarchy(model, args?.where); + this.injectSelectIncludeHierarchy(model, args); + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`${method}\` ${this.getModelName(model)}: ${formatObject(args)}`); + } + const entity = await db[model][method](args); + + if (Array.isArray(entity)) { + return entity.map((item) => this.assembleHierarchy(model, item)); + } else { + return this.assembleHierarchy(model, entity); + } + } + + private injectWhereHierarchy(model: string, where: any) { + if (!where || typeof where !== 'object') { + return; + } + + Object.entries(where).forEach(([field, value]) => { + const fieldInfo = resolveField(this.options.modelMeta, model, field); + if (!fieldInfo?.inheritedFrom) { + return; + } + + let base = this.getBaseModel(model); + let target = where; + + while (base) { + const baseRelationName = this.makeAuxRelationName(base); + + // prepare base layer where + let thisLayer: any; + if (target[baseRelationName]) { + thisLayer = target[baseRelationName]; + } else { + thisLayer = target[baseRelationName] = {}; + } + + if (base.name === fieldInfo.inheritedFrom) { + thisLayer[field] = value; + delete where[field]; + break; + } else { + target = thisLayer; + base = this.getBaseModel(base.name); + } + } + }); + } + + private buildWhereHierarchy(where: any) { + if (!where) { + return undefined; + } + + where = deepcopy(where); + Object.entries(where).forEach(([field, value]) => { + const fieldInfo = resolveField(this.options.modelMeta, this.model, field); + if (!fieldInfo?.inheritedFrom) { + return; + } + + let base = this.getBaseModel(this.model); + let target = where; + + while (base) { + const baseRelationName = this.makeAuxRelationName(base); + + // prepare base layer where + let thisLayer: any; + if (target[baseRelationName]) { + thisLayer = target[baseRelationName]; + } else { + thisLayer = target[baseRelationName] = {}; + } + + if (base.name === fieldInfo.inheritedFrom) { + thisLayer[field] = value; + delete where[field]; + break; + } else { + target = thisLayer; + base = this.getBaseModel(base.name); + } + } + }); + + return where; + } + + private injectSelectIncludeHierarchy(model: string, args: any) { + if (!args || typeof args !== 'object') { + return; + } + + for (const kind of ['select', 'include'] as const) { + if (args[kind] && typeof args[kind] === 'object') { + for (const [field, value] of Object.entries(args[kind])) { + if (value !== undefined) { + if (this.injectBaseFieldSelect(model, field, value, args, kind)) { + delete args[kind][field]; + } else { + const fieldInfo = resolveField(this.options.modelMeta, model, field); + if (fieldInfo && this.isDelegateOrDescendantOfDelegate(fieldInfo.type)) { + let nextValue = value; + if (nextValue === true) { + // make sure the payload is an object + args[kind][field] = nextValue = {}; + } + this.injectSelectIncludeHierarchy(fieldInfo.type, nextValue); + } + } + } + } + } + } + + if (!args.select) { + this.injectBaseIncludeRecursively(model, args); + } + } + + private buildSelectIncludeHierarchy(model: string, args: any) { + args = deepcopy(args); + const selectInclude: any = this.extractSelectInclude(args) || {}; + + if (selectInclude.select && typeof selectInclude.select === 'object') { + Object.entries(selectInclude.select).forEach(([field, value]) => { + if (value) { + if (this.injectBaseFieldSelect(model, field, value, selectInclude, 'select')) { + delete selectInclude.select[field]; + } + } + }); + } else if (selectInclude.include && typeof selectInclude.include === 'object') { + Object.entries(selectInclude.include).forEach(([field, value]) => { + if (value) { + if (this.injectBaseFieldSelect(model, field, value, selectInclude, 'include')) { + delete selectInclude.include[field]; + } + } + }); + } + + if (!selectInclude.select) { + this.injectBaseIncludeRecursively(model, selectInclude); + } + return selectInclude; + } + + private injectBaseFieldSelect( + model: string, + field: string, + value: any, + selectInclude: any, + context: 'select' | 'include' + ) { + const fieldInfo = resolveField(this.options.modelMeta, model, field); + if (!fieldInfo?.inheritedFrom) { + return false; + } + + let base = this.getBaseModel(model); + let target = selectInclude; + + while (base) { + const baseRelationName = this.makeAuxRelationName(base); + + // prepare base layer select/include + // let selectOrInclude = 'select'; + let thisLayer: any; + if (target.include) { + // selectOrInclude = 'include'; + thisLayer = target.include; + } else if (target.select) { + // selectOrInclude = 'select'; + thisLayer = target.select; + } else { + // selectInclude = 'include'; + thisLayer = target.select = {}; + } + + if (base.name === fieldInfo.inheritedFrom) { + if (!thisLayer[baseRelationName]) { + thisLayer[baseRelationName] = { [context]: {} }; + } + thisLayer[baseRelationName][context][field] = value; + break; + } else { + if (!thisLayer[baseRelationName]) { + thisLayer[baseRelationName] = { select: {} }; + } + target = thisLayer[baseRelationName]; + base = this.getBaseModel(base.name); + } + } + + return true; + } + + private injectBaseIncludeRecursively(model: string, selectInclude: any) { + const base = this.getBaseModel(model); + if (!base) { + return; + } + const baseRelationName = this.makeAuxRelationName(base); + + if (selectInclude.select) { + selectInclude.include = { [baseRelationName]: {}, ...selectInclude.select }; + delete selectInclude.select; + } else { + selectInclude.include = { [baseRelationName]: {}, ...selectInclude.include }; + } + this.injectBaseIncludeRecursively(base.name, selectInclude.include[baseRelationName]); + } + + // #endregion + + // #region create + + override async create(args: any) { + if (!args) { + throw prismaClientValidationError(this.prisma, this.options.prismaModule, 'query argument is required'); + } + if (!args.data) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + 'data field is required in query argument' + ); + } + + if (isDelegateModel(this.options.modelMeta, this.model)) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + `Model "${this.model}" is a delegate and cannot be created directly` + ); + } + + if (!this.involvesDelegateModel(this.model)) { + return super.create(args); + } + + return this.doCreate(this.prisma, this.model, args); + } + + override createMany(args: { data: any; skipDuplicates?: boolean }): Promise<{ count: number }> { + if (!args) { + throw prismaClientValidationError(this.prisma, this.options.prismaModule, 'query argument is required'); + } + if (!args.data) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + 'data field is required in query argument' + ); + } + + if (!this.involvesDelegateModel(this.model)) { + return super.createMany(args); + } + + if (this.isDelegateOrDescendantOfDelegate(this.model) && args.skipDuplicates) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + '`createMany` with `skipDuplicates` set to true is not supported for delegated models' + ); + } + + // note that we can't call `createMany` directly because it doesn't support + // nested created, which is needed for creating base entities + return this.queryUtils.transaction(this.prisma, async (tx) => { + const r = await Promise.all( + enumerate(args.data).map(async (item) => { + return this.doCreate(tx, this.model, item); + }) + ); + + // filter out undefined value (due to skipping duplicates) + return { count: r.filter((item) => !!item).length }; + }); + } + + private async doCreate(db: CrudContract, model: string, args: any) { + args = deepcopy(args); + + await this.injectCreateHierarchy(model, args); + this.injectSelectIncludeHierarchy(model, args); + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`create\` ${this.getModelName(model)}: ${formatObject(args)}`); + } + const result = await db[model].create(args); + return this.assembleHierarchy(model, result); + } + + private async injectCreateHierarchy(model: string, args: any) { + const visitor = new NestedWriteVisitor(this.options.modelMeta, { + create: (model, args, _context) => { + this.doProcessCreatePayload(model, args); + }, + + createMany: (model, args, _context) => { + if (args.skipDuplicates) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + '`createMany` with `skipDuplicates` set to true is not supported for delegated models' + ); + } + + for (const item of enumerate(args?.data)) { + this.doProcessCreatePayload(model, item); + } + }, + }); + + await visitor.visit(model, 'create', args); + } + + private doProcessCreatePayload(model: string, args: any) { + if (!args) { + return; + } + + this.ensureBaseCreateHierarchy(model, args); + + for (const [field, value] of Object.entries(args)) { + const fieldInfo = resolveField(this.options.modelMeta, model, field); + if (fieldInfo?.inheritedFrom) { + this.injectBaseFieldData(model, fieldInfo, value, args, 'create'); + delete args[field]; + } + } + } + + // ensure the full nested "create" structure is created for base types + private ensureBaseCreateHierarchy(model: string, result: any) { + let curr = result; + let base = this.getBaseModel(model); + let sub = this.getModelInfo(model); + + while (base) { + const baseRelationName = this.makeAuxRelationName(base); + + if (!curr[baseRelationName]) { + curr[baseRelationName] = {}; + } + if (!curr[baseRelationName].create) { + curr[baseRelationName].create = {}; + if (base.discriminator) { + // set discriminator field + curr[baseRelationName].create[base.discriminator] = sub.name; + } + } + curr = curr[baseRelationName].create; + sub = base; + base = this.getBaseModel(base.name); + } + } + + // inject field data that belongs to base type into proper nesting structure + private injectBaseFieldData( + model: string, + fieldInfo: FieldInfo, + value: unknown, + args: any, + mode: 'create' | 'update' + ) { + let base = this.getBaseModel(model); + let curr = args; + + while (base) { + if (base.discriminator === fieldInfo.name) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + `fields "${fieldInfo.name}" is a discriminator and cannot be set directly` + ); + } + + const baseRelationName = this.makeAuxRelationName(base); + + if (!curr[baseRelationName]) { + curr[baseRelationName] = {}; + } + if (!curr[baseRelationName][mode]) { + curr[baseRelationName][mode] = {}; + } + curr = curr[baseRelationName][mode]; + + if (fieldInfo.inheritedFrom === base.name) { + curr[fieldInfo.name] = value; + break; + } + + base = this.getBaseModel(base.name); + } + } + + // #endregion + + // #region update + + override update(args: any): Promise { + if (!args) { + throw prismaClientValidationError(this.prisma, this.options.prismaModule, 'query argument is required'); + } + if (!args.data) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + 'data field is required in query argument' + ); + } + + if (!this.involvesDelegateModel(this.model)) { + return super.update(args); + } + + return this.queryUtils.transaction(this.prisma, (tx) => this.doUpdate(tx, this.model, args)); + } + + override async updateMany(args: any): Promise<{ count: number }> { + if (!args) { + throw prismaClientValidationError(this.prisma, this.options.prismaModule, 'query argument is required'); + } + if (!args.data) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + 'data field is required in query argument' + ); + } + + if (!this.involvesDelegateModel(this.model)) { + return super.updateMany(args); + } + + const simpleUpdateMany = Object.keys(args.data).every((key) => { + // check if the `data` clause involves base fields + const fieldInfo = resolveField(this.options.modelMeta, this.model, key); + return !fieldInfo?.inheritedFrom; + }); + + return this.queryUtils.transaction(this.prisma, (tx) => + this.doUpdateMany(tx, this.model, args, simpleUpdateMany) + ); + } + + override async upsert(args: any): Promise { + if (!args) { + throw prismaClientValidationError(this.prisma, this.options.prismaModule, 'query argument is required'); + } + if (!args.where) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + 'where field is required in query argument' + ); + } + + if (isDelegateModel(this.options.modelMeta, this.model)) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + `Model "${this.model}" is a delegate and doesn't support upsert` + ); + } + + if (!this.involvesDelegateModel(this.model)) { + return super.upsert(args); + } + + args = deepcopy(args); + this.injectWhereHierarchy(this.model, (args as any)?.where); + this.injectSelectIncludeHierarchy(this.model, args); + if (args.create) { + this.doProcessCreatePayload(this.model, args.create); + } + if (args.update) { + this.doProcessUpdatePayload(this.model, args.update); + } + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`upsert\` ${this.getModelName(this.model)}: ${formatObject(args)}`); + } + const result = await this.prisma[this.model].upsert(args); + return this.assembleHierarchy(this.model, result); + } + + private async doUpdate(db: CrudContract, model: string, args: any): Promise { + args = deepcopy(args); + + await this.injectUpdateHierarchy(db, model, args); + this.injectSelectIncludeHierarchy(model, args); + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`update\` ${this.getModelName(model)}: ${formatObject(args)}`); + } + const result = await db[model].update(args); + return this.assembleHierarchy(model, result); + } + + private async doUpdateMany( + db: CrudContract, + model: string, + args: any, + simpleUpdateMany: boolean + ): Promise<{ count: number }> { + if (simpleUpdateMany) { + // do a direct `updateMany` + args = deepcopy(args); + await this.injectUpdateHierarchy(db, model, args); + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`updateMany\` ${this.getModelName(model)}: ${formatObject(args)}`); + } + return db[model].updateMany(args); + } else { + // translate to plain `update` for nested write into base fields + const findArgs = { + where: deepcopy(args.where), + select: this.queryUtils.makeIdSelection(model), + }; + await this.injectUpdateHierarchy(db, model, findArgs); + if (this.options.logPrismaQuery) { + this.logger.info( + `[delegate] \`updateMany\` find candidates: ${this.getModelName(model)}: ${formatObject(findArgs)}` + ); + } + const entities = await db[model].findMany(findArgs); + + const updatePayload = { data: deepcopy(args.data), select: this.queryUtils.makeIdSelection(model) }; + await this.injectUpdateHierarchy(db, model, updatePayload); + const result = await Promise.all( + entities.map((entity) => { + const updateArgs = { + where: entity, + ...updatePayload, + }; + this.logger.info( + `[delegate] \`updateMany\` update: ${this.getModelName(model)}: ${formatObject(updateArgs)}` + ); + return db[model].update(updateArgs); + }) + ); + return { count: result.length }; + } + } + + private async injectUpdateHierarchy(db: CrudContract, model: string, args: any) { + const visitor = new NestedWriteVisitor(this.options.modelMeta, { + update: (model, args, _context) => { + this.injectWhereHierarchy(model, (args as any)?.where); + this.doProcessUpdatePayload(model, (args as any)?.data); + }, + + updateMany: async (model, args, context) => { + let simpleUpdateMany = Object.keys(args.data).every((key) => { + // check if the `data` clause involves base fields + const fieldInfo = resolveField(this.options.modelMeta, model, key); + return !fieldInfo?.inheritedFrom; + }); + + if (simpleUpdateMany) { + // check if the `where` clause involves base fields + simpleUpdateMany = Object.keys(args.where || {}).every((key) => { + const fieldInfo = resolveField(this.options.modelMeta, model, key); + return !fieldInfo?.inheritedFrom; + }); + } + + if (simpleUpdateMany) { + this.injectWhereHierarchy(model, (args as any)?.where); + this.doProcessUpdatePayload(model, (args as any)?.data); + } else { + const where = this.queryUtils.buildReversedQuery(context, false, false); + await this.queryUtils.transaction(db, async (tx) => { + await this.doUpdateMany(tx, model, { ...args, where }, simpleUpdateMany); + }); + delete context.parent['updateMany']; + } + }, + + upsert: (model, args, _context) => { + this.injectWhereHierarchy(model, (args as any)?.where); + if (args.create) { + this.doProcessCreatePayload(model, (args as any)?.create); + } + if (args.update) { + this.doProcessUpdatePayload(model, (args as any)?.update); + } + }, + + create: (model, args, _context) => { + if (isDelegateModel(this.options.modelMeta, model)) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + `Model "${model}" is a delegate and cannot be created directly` + ); + } + this.doProcessCreatePayload(model, args); + }, + + createMany: (model, args, _context) => { + if (args.skipDuplicates) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + '`createMany` with `skipDuplicates` set to true is not supported for delegated models' + ); + } + + for (const item of enumerate(args?.data)) { + this.doProcessCreatePayload(model, item); + } + }, + + connect: (model, args, _context) => { + this.injectWhereHierarchy(model, args); + }, + + connectOrCreate: (model, args, _context) => { + this.injectWhereHierarchy(model, args.where); + if (args.create) { + this.doProcessCreatePayload(model, args.create); + } + }, + + disconnect: (model, args, _context) => { + this.injectWhereHierarchy(model, args); + }, + + set: (model, args, _context) => { + this.injectWhereHierarchy(model, args); + }, + + delete: async (model, _args, context) => { + const where = this.queryUtils.buildReversedQuery(context, false, false); + await this.queryUtils.transaction(db, async (tx) => { + await this.doDelete(tx, model, { where }); + }); + delete context.parent['delete']; + }, + + deleteMany: async (model, _args, context) => { + const where = this.queryUtils.buildReversedQuery(context, false, false); + await this.queryUtils.transaction(db, async (tx) => { + await this.doDeleteMany(tx, model, where); + }); + delete context.parent['deleteMany']; + }, + }); + + await visitor.visit(model, 'update', args); + } + + private doProcessUpdatePayload(model: string, data: any) { + if (!data) { + return; + } + + for (const [field, value] of Object.entries(data)) { + const fieldInfo = resolveField(this.options.modelMeta, model, field); + if (fieldInfo?.inheritedFrom) { + this.injectBaseFieldData(model, fieldInfo, value, data, 'update'); + delete data[field]; + } + } + } + + // #endregion + + // #region delete + + override delete(args: any): Promise { + if (!args) { + throw prismaClientValidationError(this.prisma, this.options.prismaModule, 'query argument is required'); + } + + if (!this.involvesDelegateModel(this.model)) { + return super.delete(args); + } + + return this.queryUtils.transaction(this.prisma, async (tx) => { + const selectInclude = this.buildSelectIncludeHierarchy(this.model, args); + + // make sure id fields are selected + const idFields = this.getIdFields(this.model); + for (const idField of idFields) { + if (selectInclude?.select && !(idField.name in selectInclude.select)) { + selectInclude.select[idField.name] = true; + } + } + + const deleteArgs = { ...deepcopy(args), ...selectInclude }; + return this.doDelete(tx, this.model, deleteArgs); + }); + } + + override deleteMany(args: any): Promise<{ count: number }> { + if (!this.involvesDelegateModel(this.model)) { + return super.deleteMany(args); + } + + return this.queryUtils.transaction(this.prisma, (tx) => this.doDeleteMany(tx, this.model, args?.where)); + } + + private async doDeleteMany(db: CrudContract, model: string, where: any): Promise<{ count: number }> { + // query existing entities with id + const idSelection = this.queryUtils.makeIdSelection(model); + const findArgs = { where: deepcopy(where), select: idSelection }; + this.injectWhereHierarchy(model, findArgs.where); + + if (this.options.logPrismaQuery) { + this.logger.info( + `[delegate] \`deleteMany\` find candidates: ${this.getModelName(model)}: ${formatObject(findArgs)}` + ); + } + const entities = await db[model].findMany(findArgs); + + // recursively delete base entities (they all have the same id values) + await Promise.all(entities.map((entity) => this.doDelete(db, model, { where: entity }))); + + return { count: entities.length }; + } + + private async deleteBaseRecursively(db: CrudContract, model: string, idValues: any) { + let base = this.getBaseModel(model); + while (base) { + await db[base.name].delete({ where: idValues }); + base = this.getBaseModel(base.name); + } + } + + private async doDelete(db: CrudContract, model: string, args: any): Promise { + this.injectWhereHierarchy(model, args.where); + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`delete\` ${this.getModelName(model)}: ${formatObject(args)}`); + } + const result = await db[model].delete(args); + const idValues = this.queryUtils.getEntityIds(model, result); + + // recursively delete base entities (they all have the same id values) + await this.deleteBaseRecursively(db, model, idValues); + return this.assembleHierarchy(model, result); + } + + // #endregion + + // #region aggregation + + override aggregate(args: any): Promise { + if (!args) { + throw prismaClientValidationError(this.prisma, this.options.prismaModule, 'query argument is required'); + } + if (!this.involvesDelegateModel(this.model)) { + return super.aggregate(args); + } + + // check if any aggregation operator is using fields from base + this.checkAggregationArgs('aggregate', args); + + args = deepcopy(args); + + if (args.cursor) { + args.cursor = this.buildWhereHierarchy(args.cursor); + } + + if (args.orderBy) { + args.orderBy = this.buildWhereHierarchy(args.orderBy); + } + + if (args.where) { + args.where = this.buildWhereHierarchy(args.where); + } + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`aggregate\` ${this.getModelName(this.model)}: ${formatObject(args)}`); + } + return super.aggregate(args); + } + + override count(args: any): Promise { + if (!this.involvesDelegateModel(this.model)) { + return super.count(args); + } + + // check if count select is using fields from base + this.checkAggregationArgs('count', args); + + args = deepcopy(args); + + if (args?.cursor) { + args.cursor = this.buildWhereHierarchy(args.cursor); + } + + if (args?.where) { + args.where = this.buildWhereHierarchy(args.where); + } + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`count\` ${this.getModelName(this.model)}: ${formatObject(args)}`); + } + return super.count(args); + } + + override groupBy(args: any): Promise { + if (!args) { + throw prismaClientValidationError(this.prisma, this.options.prismaModule, 'query argument is required'); + } + if (!this.involvesDelegateModel(this.model)) { + return super.groupBy(args); + } + + // check if count select is using fields from base + this.checkAggregationArgs('groupBy', args); + + if (args.by) { + for (const by of enumerate(args.by)) { + const fieldInfo = resolveField(this.options.modelMeta, this.model, by); + if (fieldInfo && fieldInfo.inheritedFrom) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + `groupBy with fields from base type is not supported yet: "${by}"` + ); + } + } + } + + args = deepcopy(args); + + if (args.where) { + args.where = this.buildWhereHierarchy(args.where); + } + + if (this.options.logPrismaQuery) { + this.logger.info(`[delegate] \`groupBy\` ${this.getModelName(this.model)}: ${formatObject(args)}`); + } + return super.groupBy(args); + } + + private checkAggregationArgs(operation: 'aggregate' | 'count' | 'groupBy', args: any) { + if (!args) { + return; + } + + for (const op of ['_count', '_sum', '_avg', '_min', '_max', 'select', 'having']) { + if (args[op] && typeof args[op] === 'object') { + for (const field of Object.keys(args[op])) { + const fieldInfo = resolveField(this.options.modelMeta, this.model, field); + if (fieldInfo?.inheritedFrom) { + throw prismaClientValidationError( + this.prisma, + this.options.prismaModule, + `${operation} with fields from base type is not supported yet: "${field}"` + ); + } + } + } + } + } + + // #endregion + + // #region utils + + private extractSelectInclude(args: any) { + if (!args) { + return undefined; + } + args = deepcopy(args); + return 'select' in args + ? { select: args['select'] } + : 'include' in args + ? { include: args['include'] } + : undefined; + } + + private makeAuxRelationName(model: ModelInfo) { + return `${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(model.name)}`; + } + + private getModelName(model: string) { + const info = getModelInfo(this.options.modelMeta, model, true); + return info.name; + } + + private getIdFields(model: string): FieldInfo[] { + const idFields = getIdFields(this.options.modelMeta, model); + if (idFields && idFields.length > 0) { + return idFields; + } + const base = this.getBaseModel(model); + return base ? this.getIdFields(base.name) : []; + } + + private getModelInfo(model: string) { + return getModelInfo(this.options.modelMeta, model, true); + } + + private getBaseModel(model: string) { + const baseNames = getModelInfo(this.options.modelMeta, model, true).baseTypes; + if (!baseNames) { + return undefined; + } + if (baseNames.length > 1) { + throw new Error('Multi-inheritance is not supported'); + } + return this.options.modelMeta.models[lowerCaseFirst(baseNames[0])]; + } + + private involvesDelegateModel(model: string, visited?: Set): boolean { + if (this.isDelegateOrDescendantOfDelegate(model)) { + return true; + } + + visited = visited ?? new Set(); + if (visited.has(model)) { + return false; + } + visited.add(model); + + const modelInfo = getModelInfo(this.options.modelMeta, model, true); + return Object.values(modelInfo.fields).some( + (field) => field.isDataModel && this.involvesDelegateModel(field.type, visited) + ); + } + + private isDelegateOrDescendantOfDelegate(model: string): boolean { + if (isDelegateModel(this.options.modelMeta, model)) { + return true; + } + const baseTypes = getModelInfo(this.options.modelMeta, model)?.baseTypes; + return !!( + baseTypes && + baseTypes.length > 0 && + baseTypes.some((base) => this.isDelegateOrDescendantOfDelegate(base)) + ); + } + + private assembleHierarchy(model: string, entity: any) { + if (!entity || typeof entity !== 'object') { + return entity; + } + + const result: any = {}; + const base = this.getBaseModel(model); + + if (base) { + // merge base fields + const baseRelationName = this.makeAuxRelationName(base); + const baseData = entity[baseRelationName]; + if (baseData && typeof baseData === 'object') { + const baseAssembled = this.assembleHierarchy(base.name, baseData); + Object.assign(result, baseAssembled); + } + } + + const modelInfo = getModelInfo(this.options.modelMeta, model, true); + + for (const field of Object.values(modelInfo.fields)) { + if (field.inheritedFrom) { + // already merged from base + continue; + } + + if (field.name in entity) { + const fieldValue = entity[field.name]; + if (field.isDataModel) { + if (Array.isArray(fieldValue)) { + result[field.name] = fieldValue.map((item) => this.assembleHierarchy(field.type, item)); + } else { + result[field.name] = this.assembleHierarchy(field.type, fieldValue); + } + } else { + result[field.name] = fieldValue; + } + } + } + + return result; + } + + // #endregion + + // #region backup + + private transformWhereHierarchy(where: any, contextModel: ModelInfo, forModel: ModelInfo) { + if (!where || typeof where !== 'object') { + return where; + } + + let curr: ModelInfo | undefined = contextModel; + const inheritStack: ModelInfo[] = []; + while (curr) { + inheritStack.unshift(curr); + curr = this.getBaseModel(curr.name); + } + + let result: any = {}; + for (const [key, value] of Object.entries(where)) { + const fieldInfo = requireField(this.options.modelMeta, contextModel.name, key); + const fieldHierarchy = this.transformFieldHierarchy(fieldInfo, value, contextModel, forModel, inheritStack); + result = deepmerge(result, fieldHierarchy); + } + + return result; + } + + private transformFieldHierarchy( + fieldInfo: FieldInfo, + value: unknown, + contextModel: ModelInfo, + forModel: ModelInfo, + inheritStack: ModelInfo[] + ): any { + const fieldModel = fieldInfo.inheritedFrom ? this.getModelInfo(fieldInfo.inheritedFrom) : contextModel; + if (fieldModel === forModel) { + return { [fieldInfo.name]: value }; + } + + const fieldModelPos = inheritStack.findIndex((m) => m === fieldModel); + const forModelPos = inheritStack.findIndex((m) => m === forModel); + const result: any = {}; + let curr = result; + + if (fieldModelPos > forModelPos) { + // walk down hierarchy + for (let i = forModelPos + 1; i <= fieldModelPos; i++) { + const rel = this.makeAuxRelationName(inheritStack[i]); + curr[rel] = {}; + curr = curr[rel]; + } + } else { + // walk up hierarchy + for (let i = forModelPos - 1; i >= fieldModelPos; i--) { + const rel = this.makeAuxRelationName(inheritStack[i]); + curr[rel] = {}; + curr = curr[rel]; + } + } + + curr[fieldInfo.name] = value; + return result; + } + + // #endregion +} diff --git a/packages/runtime/src/enhancements/policy/logger.ts b/packages/runtime/src/enhancements/logger.ts similarity index 100% rename from packages/runtime/src/enhancements/policy/logger.ts rename to packages/runtime/src/enhancements/logger.ts diff --git a/packages/runtime/src/enhancements/omit.ts b/packages/runtime/src/enhancements/omit.ts index bedbf5458..e05a8a769 100644 --- a/packages/runtime/src/enhancements/omit.ts +++ b/packages/runtime/src/enhancements/omit.ts @@ -21,8 +21,8 @@ export function withOmit(prisma: DbClient, options: Enh } class OmitHandler extends DefaultPrismaProxyHandler { - constructor(prisma: DbClientContract, model: string, private readonly options: EnhancementOptions) { - super(prisma, model); + constructor(prisma: DbClientContract, model: string, options: EnhancementOptions) { + super(prisma, model, options); } // base override diff --git a/packages/runtime/src/enhancements/password.ts b/packages/runtime/src/enhancements/password.ts index db8e3181b..7fef04dd8 100644 --- a/packages/runtime/src/enhancements/password.ts +++ b/packages/runtime/src/enhancements/password.ts @@ -23,8 +23,8 @@ export function withPassword(prisma: DbClient, op } class PasswordHandler extends DefaultPrismaProxyHandler { - constructor(prisma: DbClientContract, model: string, private readonly options: EnhancementOptions) { - super(prisma, model); + constructor(prisma: DbClientContract, model: string, options: EnhancementOptions) { + super(prisma, model, options); } // base override diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 65207abea..1bc60a647 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -16,11 +16,12 @@ import { type FieldInfo, type ModelMeta, } from '../../cross'; -import { DbClientContract, DbOperations, PolicyOperationKind } from '../../types'; +import { type CrudContract, type DbClientContract, PolicyOperationKind } from '../../types'; import type { EnhancementContext, EnhancementOptions } from '../create-enhancement'; +import { Logger } from '../logger'; import { PrismaProxyHandler } from '../proxy'; +import { QueryUtils } from '../query-utils'; import { formatObject, prismaClientValidationError } from '../utils'; -import { Logger } from './logger'; import { PolicyUtil } from './policy-utils'; import { createDeferredPromise } from './promise'; @@ -39,14 +40,11 @@ type FindOperations = 'findUnique' | 'findUniqueOrThrow' | 'findFirst' | 'findFi */ export class PolicyProxyHandler implements PrismaProxyHandler { private readonly logger: Logger; - private readonly utils: PolicyUtil; + private readonly policyUtils: PolicyUtil; private readonly model: string; private readonly modelMeta: ModelMeta; private readonly prismaModule: any; - private readonly logPrismaQuery?: boolean; - - private readonly DEFAULT_TX_MAXWAIT = 100000; - private readonly DEFAULT_TX_TIMEOUT = 100000; + private readonly queryUtils: QueryUtils; constructor( private readonly prisma: DbClient, @@ -57,9 +55,10 @@ export class PolicyProxyHandler implements Pr this.logger = new Logger(prisma); this.model = lowerCaseFirst(model); - ({ modelMeta: this.modelMeta, logPrismaQuery: this.logPrismaQuery, prismaModule: this.prismaModule } = options); + ({ modelMeta: this.modelMeta, prismaModule: this.prismaModule } = options); - this.utils = new PolicyUtil(prisma, options, context, this.shouldLogQuery); + this.policyUtils = new PolicyUtil(prisma, options, context, this.shouldLogQuery); + this.queryUtils = new QueryUtils(prisma, options); } private get modelClient() { @@ -96,7 +95,7 @@ export class PolicyProxyHandler implements Pr ); } return this.findWithFluentCallStubs(args, 'findUniqueOrThrow', true, () => { - throw this.utils.notFound(this.model); + throw this.policyUtils.notFound(this.model); }); } @@ -106,7 +105,7 @@ export class PolicyProxyHandler implements Pr findFirstOrThrow(args: any) { return this.findWithFluentCallStubs(args, 'findFirstOrThrow', true, () => { - throw this.utils.notFound(this.model); + throw this.policyUtils.notFound(this.model); }); } @@ -129,12 +128,15 @@ export class PolicyProxyHandler implements Pr private doFind(args: any, actionName: FindOperations, handleRejection: () => any) { const origArgs = args; - const _args = this.utils.clone(args); - if (!this.utils.injectForRead(this.prisma, this.model, _args)) { + const _args = this.policyUtils.clone(args); + if (!this.policyUtils.injectForRead(this.prisma, this.model, _args)) { + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`${actionName}\` ${this.model}: unconditionally denied`); + } return handleRejection(); } - this.utils.injectReadCheckSelect(this.model, _args); + this.policyUtils.injectReadCheckSelect(this.model, _args); if (this.shouldLogQuery) { this.logger.info(`[policy] \`${actionName}\` ${this.model}:\n${formatObject(_args)}`); @@ -143,7 +145,7 @@ export class PolicyProxyHandler implements Pr return new Promise((resolve, reject) => { this.modelClient[actionName](_args).then( (value: any) => { - this.utils.postProcessForRead(value, this.model, origArgs); + this.policyUtils.postProcessForRead(value, this.model, origArgs); resolve(value); }, (err: any) => reject(err) @@ -154,14 +156,14 @@ export class PolicyProxyHandler implements Pr // returns a fluent API call function private fluentCall(filter: any, fieldInfo: FieldInfo, rootPromise?: Promise) { return (args: any) => { - args = this.utils.clone(args); + args = this.policyUtils.clone(args); // combine the parent filter with the current one const backLinkField = this.requireBackLink(fieldInfo); const condition = backLinkField.isArray ? { [backLinkField.name]: { some: filter } } : { [backLinkField.name]: { is: filter } }; - args.where = this.utils.and(args.where, condition); + args.where = this.policyUtils.and(args.where, condition); const promise = createDeferredPromise(() => { // Promise for fetching @@ -207,7 +209,7 @@ export class PolicyProxyHandler implements Pr // add fluent API functions to the given promise private addFluentFunctions(promise: any, model: string, filter: any, rootPromise?: Promise) { - const fields = this.utils.getModelFields(model); + const fields = this.policyUtils.getModelFields(model); if (fields) { for (const [field, fieldInfo] of Object.entries(fields)) { if (fieldInfo.isDataModel) { @@ -233,20 +235,25 @@ export class PolicyProxyHandler implements Pr ); } - this.utils.tryReject(this.prisma, this.model, 'create'); + this.policyUtils.tryReject(this.prisma, this.model, 'create'); const origArgs = args; - args = this.utils.clone(args); + args = this.policyUtils.clone(args); // static input policy check for top-level create data - const inputCheck = this.utils.checkInputGuard(this.model, args.data, 'create'); + const inputCheck = this.policyUtils.checkInputGuard(this.model, args.data, 'create'); if (inputCheck === false) { - throw this.utils.deniedByPolicy(this.model, 'create', undefined, CrudFailureReason.ACCESS_POLICY_VIOLATION); + throw this.policyUtils.deniedByPolicy( + this.model, + 'create', + undefined, + CrudFailureReason.ACCESS_POLICY_VIOLATION + ); } const hasNestedCreateOrConnect = await this.hasNestedCreateOrConnect(args); - const { result, error } = await this.transaction(async (tx) => { + const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { if ( // MUST check true here since inputCheck can be undefined (meaning static input check not possible) inputCheck === true && @@ -259,7 +266,7 @@ export class PolicyProxyHandler implements Pr this.validateCreateInputSchema(this.model, args.data); // make a create args only containing data and ID selection - const createArgs: any = { data: args.data, select: this.utils.makeIdSelection(this.model) }; + const createArgs: any = { data: args.data, select: this.policyUtils.makeIdSelection(this.model) }; if (this.shouldLogQuery) { this.logger.info(`[policy] \`create\` ${this.model}: ${formatObject(createArgs)}`); @@ -267,7 +274,7 @@ export class PolicyProxyHandler implements Pr const result = await tx[this.model].create(createArgs); // filter the read-back data - return this.utils.readBack(tx, this.model, 'create', args, result); + return this.policyUtils.readBack(tx, this.model, 'create', args, result); } else { // proceed with a complex create and collect post-write checks const { result, postWriteChecks } = await this.doCreate(this.model, args, tx); @@ -276,7 +283,7 @@ export class PolicyProxyHandler implements Pr await this.runPostWriteChecks(postWriteChecks, tx); // filter the read-back data - return this.utils.readBack(tx, this.model, 'create', origArgs, result); + return this.policyUtils.readBack(tx, this.model, 'create', origArgs, result); } }); @@ -288,7 +295,7 @@ export class PolicyProxyHandler implements Pr } // create with nested write - private async doCreate(model: string, args: any, db: Record) { + private async doCreate(model: string, args: any, db: CrudContract) { // record id fields involved in the nesting context const idSelections: Array<{ path: FieldInfo[]; ids: string[] }> = []; const pushIdFields = (model: string, context: NestedWriteVisitorContext) => { @@ -323,12 +330,12 @@ export class PolicyProxyHandler implements Pr connectOrCreate: async (model, args, context) => { if (!args.where) { - throw this.utils.validationError(`'where' field is required for connectOrCreate`); + throw this.policyUtils.validationError(`'where' field is required for connectOrCreate`); } this.validateCreateInputSchema(model, args.create); - const existing = await this.utils.checkExistence(db, model, args.where); + const existing = await this.policyUtils.checkExistence(db, model, args.where); if (existing) { // connect case if (context.field?.backLink) { @@ -336,7 +343,7 @@ export class PolicyProxyHandler implements Pr if (backLinkField?.isRelationOwner) { // the target side of relation owns the relation, // check if it's updatable - await this.utils.checkPolicyForUnique(model, args.where, 'update', db, args); + await this.policyUtils.checkPolicyForUnique(model, args.where, 'update', db, args); } } @@ -370,18 +377,18 @@ export class PolicyProxyHandler implements Pr connect: async (model, args, context) => { if (!args || typeof args !== 'object' || Object.keys(args).length === 0) { - throw this.utils.validationError(`'connect' field must be an non-empty object`); + throw this.policyUtils.validationError(`'connect' field must be an non-empty object`); } if (context.field?.backLink) { const backLinkField = resolveField(this.modelMeta, model, context.field.backLink); if (backLinkField?.isRelationOwner) { // check existence - await this.utils.checkExistence(db, model, args, true); + await this.policyUtils.checkExistence(db, model, args, true); // the target side of relation owns the relation, // check if it's updatable - await this.utils.checkPolicyForUnique(model, args, 'update', db, args); + await this.policyUtils.checkPolicyForUnique(model, args, 'update', db, args); } } }, @@ -426,7 +433,7 @@ export class PolicyProxyHandler implements Pr }); // return only the ids of the top-level entity - const ids = this.utils.getEntityIds(this.model, result); + const ids = this.policyUtils.getEntityIds(this.model, result); return { result: ids, postWriteChecks: [...postCreateChecks.values()] }; } @@ -463,11 +470,11 @@ export class PolicyProxyHandler implements Pr // Validates the given create payload against Zod schema if any private validateCreateInputSchema(model: string, data: any) { - const schema = this.utils.getZodSchema(model, 'create'); + const schema = this.policyUtils.getZodSchema(model, 'create'); if (schema) { const parseResult = schema.safeParse(data); if (!parseResult.success) { - throw this.utils.deniedByPolicy( + throw this.policyUtils.deniedByPolicy( model, 'create', `input failed validation: ${fromZodError(parseResult.error)}`, @@ -490,16 +497,16 @@ export class PolicyProxyHandler implements Pr ); } - this.utils.tryReject(this.prisma, this.model, 'create'); + this.policyUtils.tryReject(this.prisma, this.model, 'create'); - args = this.utils.clone(args); + args = this.policyUtils.clone(args); // do static input validation and check if post-create checks are needed let needPostCreateCheck = false; for (const item of enumerate(args.data)) { - const inputCheck = this.utils.checkInputGuard(this.model, item, 'create'); + const inputCheck = this.policyUtils.checkInputGuard(this.model, item, 'create'); if (inputCheck === false) { - throw this.utils.deniedByPolicy( + throw this.policyUtils.deniedByPolicy( this.model, 'create', undefined, @@ -518,7 +525,7 @@ export class PolicyProxyHandler implements Pr return this.modelClient.createMany(args); } else { // create entities in a transaction with post-create checks - return this.transaction(async (tx) => { + return this.queryUtils.transaction(this.prisma, async (tx) => { const { result, postWriteChecks } = await this.doCreateMany(this.model, args, tx); // post-create check await this.runPostWriteChecks(postWriteChecks, tx); @@ -527,11 +534,7 @@ export class PolicyProxyHandler implements Pr } } - private async doCreateMany( - model: string, - args: { data: any; skipDuplicates?: boolean }, - db: Record - ) { + private async doCreateMany(model: string, args: { data: any; skipDuplicates?: boolean }, db: CrudContract) { // We can't call the native "createMany" because we can't get back what was created // for post-create checks. Instead, do a "create" for each item and collect the results. @@ -549,7 +552,7 @@ export class PolicyProxyHandler implements Pr if (this.shouldLogQuery) { this.logger.info(`[policy] \`create\` for \`createMany\` ${model}: ${formatObject(item)}`); } - return await db[model].create({ select: this.utils.makeIdSelection(model), data: item }); + return await db[model].create({ select: this.policyUtils.makeIdSelection(model), data: item }); }) ); @@ -566,18 +569,18 @@ export class PolicyProxyHandler implements Pr }; } - private async hasDuplicatedUniqueConstraint(model: string, createData: any, db: Record) { + private async hasDuplicatedUniqueConstraint(model: string, createData: any, db: CrudContract) { // check unique constraint conflicts // we can't rely on try/catch/ignore constraint violation error: https://github.com/prisma/prisma/issues/20496 // TODO: for simple cases we should be able to translate it to an `upsert` with empty `update` payload // for each unique constraint, check if the input item has all fields set, and if so, check if // an entity already exists, and ignore accordingly - const uniqueConstraints = this.utils.getUniqueConstraints(model); + const uniqueConstraints = this.policyUtils.getUniqueConstraints(model); for (const constraint of Object.values(uniqueConstraints)) { if (constraint.fields.every((f) => createData[f] !== undefined)) { const uniqueFilter = constraint.fields.reduce((acc, f) => ({ ...acc, [f]: createData[f] }), {}); - const existing = await this.utils.checkExistence(db, model, uniqueFilter); + const existing = await this.policyUtils.checkExistence(db, model, uniqueFilter); if (existing) { return true; } @@ -615,9 +618,9 @@ export class PolicyProxyHandler implements Pr ); } - args = this.utils.clone(args); + args = this.policyUtils.clone(args); - const { result, error } = await this.transaction(async (tx) => { + const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { // proceed with nested writes and collect post-write checks const { result, postWriteChecks } = await this.doUpdate(args, tx); @@ -625,7 +628,7 @@ export class PolicyProxyHandler implements Pr await this.runPostWriteChecks(postWriteChecks, tx); // filter the read-back data - return this.utils.readBack(tx, this.model, 'update', args, result); + return this.policyUtils.readBack(tx, this.model, 'update', args, result); }); if (error) { @@ -635,17 +638,17 @@ export class PolicyProxyHandler implements Pr } } - private async doUpdate(args: any, db: Record) { + private async doUpdate(args: any, db: CrudContract) { // collected post-update checks const postWriteChecks: PostWriteCheckRecord[] = []; // registers a post-update check task const _registerPostUpdateCheck = async (model: string, uniqueFilter: any) => { // both "post-update" rules and Zod schemas require a post-update check - if (this.utils.hasAuthGuard(model, 'postUpdate') || this.utils.getZodSchema(model)) { + if (this.policyUtils.hasAuthGuard(model, 'postUpdate') || this.policyUtils.getZodSchema(model)) { // select pre-update field values let preValue: any; - const preValueSelect = this.utils.getPreValueSelect(model); + const preValueSelect = this.policyUtils.getPreValueSelect(model); if (preValueSelect && Object.keys(preValueSelect).length > 0) { preValue = await db[model].findFirst({ where: uniqueFilter, select: preValueSelect }); } @@ -672,7 +675,7 @@ export class PolicyProxyHandler implements Pr const unsafe = this.isUnsafeMutate(model, args); // handles the connection to upstream entity - const reversedQuery = this.utils.buildReversedQuery(context, true, unsafe); + const reversedQuery = this.policyUtils.buildReversedQuery(context, true, unsafe); if ((!unsafe || context.field.isRelationOwner) && reversedQuery[context.field.backLink]) { // if mutation is safe, or current field owns the relation (so the other side has no fk), // and the reverse query contains the back link, then we can build a "connect" with it @@ -707,7 +710,7 @@ export class PolicyProxyHandler implements Pr // for example when it's nested inside a one-to-one update const upstreamQuery = { where: reversedQuery[backLinkField.name], - select: this.utils.makeIdSelection(backLinkField.type), + select: this.policyUtils.makeIdSelection(backLinkField.type), }; // fetch the upstream entity @@ -757,8 +760,8 @@ export class PolicyProxyHandler implements Pr const _connectDisconnect = async (model: string, args: any, context: NestedWriteVisitorContext) => { if (context.field?.backLink) { - const backLinkField = this.utils.getModelField(model, context.field.backLink); - if (backLinkField.isRelationOwner) { + const backLinkField = this.policyUtils.getModelField(model, context.field.backLink); + if (backLinkField?.isRelationOwner) { // update happens on the related model, require updatable, // translate args to foreign keys so field-level policies can be checked const checkArgs: any = {}; @@ -770,7 +773,7 @@ export class PolicyProxyHandler implements Pr } } } - await this.utils.checkPolicyForUnique(model, args, 'update', db, checkArgs); + await this.policyUtils.checkPolicyForUnique(model, args, 'update', db, checkArgs); // register post-update check await _registerPostUpdateCheck(model, args); @@ -782,10 +785,10 @@ export class PolicyProxyHandler implements Pr const visitor = new NestedWriteVisitor(this.modelMeta, { update: async (model, args, context) => { // build a unique query including upstream conditions - const uniqueFilter = this.utils.buildReversedQuery(context); + const uniqueFilter = this.policyUtils.buildReversedQuery(context); // handle not-found - const existing = await this.utils.checkExistence(db, model, uniqueFilter, true); + const existing = await this.policyUtils.checkExistence(db, model, uniqueFilter, true); // check if the update actually writes to this model let thisModelUpdate = false; @@ -808,13 +811,13 @@ export class PolicyProxyHandler implements Pr } if (thisModelUpdate) { - this.utils.tryReject(db, this.model, 'update'); + this.policyUtils.tryReject(db, this.model, 'update'); // check pre-update guard - await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args); + await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args); // handles the case where id fields are updated - const ids = this.utils.clone(existing); + const ids = this.policyUtils.clone(existing); for (const key of Object.keys(existing)) { const updateValue = (args as any).data ? (args as any).data[key] : (args as any)[key]; if ( @@ -833,15 +836,15 @@ export class PolicyProxyHandler implements Pr updateMany: async (model, args, context) => { // prepare for post-update check - if (this.utils.hasAuthGuard(model, 'postUpdate') || this.utils.getZodSchema(model)) { - let select = this.utils.makeIdSelection(model); - const preValueSelect = this.utils.getPreValueSelect(model); + if (this.policyUtils.hasAuthGuard(model, 'postUpdate') || this.policyUtils.getZodSchema(model)) { + let select = this.policyUtils.makeIdSelection(model); + const preValueSelect = this.policyUtils.getPreValueSelect(model); if (preValueSelect) { select = { ...select, ...preValueSelect }; } - const reversedQuery = this.utils.buildReversedQuery(context); + const reversedQuery = this.policyUtils.buildReversedQuery(context); const currentSetQuery = { select, where: reversedQuery }; - this.utils.injectAuthGuardAsWhere(db, currentSetQuery, model, 'read'); + this.policyUtils.injectAuthGuardAsWhere(db, currentSetQuery, model, 'read'); if (this.shouldLogQuery) { this.logger.info( @@ -860,15 +863,15 @@ export class PolicyProxyHandler implements Pr ); } - const updateGuard = this.utils.getAuthGuard(db, model, 'update'); - if (this.utils.isTrue(updateGuard) || this.utils.isFalse(updateGuard)) { + const updateGuard = this.policyUtils.getAuthGuard(db, model, 'update'); + if (this.policyUtils.isTrue(updateGuard) || this.policyUtils.isFalse(updateGuard)) { // injects simple auth guard into where clause - this.utils.injectAuthGuardAsWhere(db, args, model, 'update'); + this.policyUtils.injectAuthGuardAsWhere(db, args, model, 'update'); } else { // we have to process `updateMany` separately because the guard may contain // filters using relation fields which are not allowed in nested `updateMany` - const reversedQuery = this.utils.buildReversedQuery(context); - const updateWhere = this.utils.and(reversedQuery, updateGuard); + const reversedQuery = this.policyUtils.buildReversedQuery(context); + const updateWhere = this.policyUtils.and(reversedQuery, updateGuard); if (this.shouldLogQuery) { this.logger.info( `[policy] \`updateMany\` ${model}:\n${formatObject({ @@ -906,15 +909,15 @@ export class PolicyProxyHandler implements Pr upsert: async (model, args, context) => { // build a unique query including upstream conditions - const uniqueFilter = this.utils.buildReversedQuery(context); + const uniqueFilter = this.policyUtils.buildReversedQuery(context); // branch based on if the update target exists - const existing = await this.utils.checkExistence(db, model, uniqueFilter); + const existing = await this.policyUtils.checkExistence(db, model, uniqueFilter); if (existing) { // update case // check pre-update guard - await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args); + await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args); // register post-update check await _registerPostUpdateCheck(model, uniqueFilter); @@ -943,7 +946,7 @@ export class PolicyProxyHandler implements Pr connectOrCreate: async (model, args, context) => { // the where condition is already unique, so we can use it to check if the target exists - const existing = await this.utils.checkExistence(db, model, args.where); + const existing = await this.policyUtils.checkExistence(db, model, args.where); if (existing) { // connect await _connectDisconnect(model, args.where, context); @@ -957,9 +960,9 @@ export class PolicyProxyHandler implements Pr set: async (model, args, context) => { // find the set of items to be replaced - const reversedQuery = this.utils.buildReversedQuery(context); + const reversedQuery = this.policyUtils.buildReversedQuery(context); const findCurrSetArgs = { - select: this.utils.makeIdSelection(model), + select: this.policyUtils.makeIdSelection(model), where: reversedQuery, }; if (this.shouldLogQuery) { @@ -976,25 +979,25 @@ export class PolicyProxyHandler implements Pr delete: async (model, args, context) => { // build a unique query including upstream conditions - const uniqueFilter = this.utils.buildReversedQuery(context); + const uniqueFilter = this.policyUtils.buildReversedQuery(context); // handle not-found - await this.utils.checkExistence(db, model, uniqueFilter, true); + await this.policyUtils.checkExistence(db, model, uniqueFilter, true); // check delete guard - await this.utils.checkPolicyForUnique(model, uniqueFilter, 'delete', db, args); + await this.policyUtils.checkPolicyForUnique(model, uniqueFilter, 'delete', db, args); }, deleteMany: async (model, args, context) => { - const guard = await this.utils.getAuthGuard(db, model, 'delete'); - if (this.utils.isTrue(guard) || this.utils.isFalse(guard)) { + const guard = await this.policyUtils.getAuthGuard(db, model, 'delete'); + if (this.policyUtils.isTrue(guard) || this.policyUtils.isFalse(guard)) { // inject simple auth guard - context.parent.deleteMany = this.utils.and(args, guard); + context.parent.deleteMany = this.policyUtils.and(args, guard); } else { // we have to process `deleteMany` separately because the guard may contain // filters using relation fields which are not allowed in nested `deleteMany` - const reversedQuery = this.utils.buildReversedQuery(context); - const deleteWhere = this.utils.and(reversedQuery, guard); + const reversedQuery = this.policyUtils.buildReversedQuery(context); + const deleteWhere = this.policyUtils.and(reversedQuery, guard); if (this.shouldLogQuery) { this.logger.info(`[policy] \`deleteMany\` ${model}:\n${formatObject({ where: deleteWhere })}`); } @@ -1013,7 +1016,7 @@ export class PolicyProxyHandler implements Pr const result = await db[this.model].update({ where: args.where, data: args.data, - select: this.utils.makeIdSelection(this.model), + select: this.policyUtils.makeIdSelection(this.model), }); return { result, postWriteChecks }; @@ -1025,7 +1028,7 @@ export class PolicyProxyHandler implements Pr } for (const k of Object.keys(args)) { const field = resolveField(this.modelMeta, model, k); - if (this.isAutoIncrementIdField(field) || field?.isForeignKey) { + if (field && (this.isAutoIncrementIdField(field) || field.isForeignKey)) { return true; } } @@ -1048,23 +1051,23 @@ export class PolicyProxyHandler implements Pr ); } - this.utils.tryReject(this.prisma, this.model, 'update'); + this.policyUtils.tryReject(this.prisma, this.model, 'update'); - args = this.utils.clone(args); - this.utils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'update'); + args = this.policyUtils.clone(args); + this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'update'); - if (this.utils.hasAuthGuard(this.model, 'postUpdate') || this.utils.getZodSchema(this.model)) { + if (this.policyUtils.hasAuthGuard(this.model, 'postUpdate') || this.policyUtils.getZodSchema(this.model)) { // use a transaction to do post-update checks const postWriteChecks: PostWriteCheckRecord[] = []; - return this.transaction(async (tx) => { + return this.queryUtils.transaction(this.prisma, async (tx) => { // collect pre-update values - let select = this.utils.makeIdSelection(this.model); - const preValueSelect = this.utils.getPreValueSelect(this.model); + let select = this.policyUtils.makeIdSelection(this.model); + const preValueSelect = this.policyUtils.getPreValueSelect(this.model); if (preValueSelect) { select = { ...select, ...preValueSelect }; } const currentSetQuery = { select, where: args.where }; - this.utils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'read'); + this.policyUtils.injectAuthGuardAsWhere(tx, currentSetQuery, this.model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`); @@ -1075,7 +1078,7 @@ export class PolicyProxyHandler implements Pr ...currentSet.map((preValue) => ({ model: this.model, operation: 'postUpdate' as PolicyOperationKind, - uniqueFilter: this.utils.getEntityIds(this.model, preValue), + uniqueFilter: this.policyUtils.getEntityIds(this.model, preValue), preValue: preValueSelect ? preValue : undefined, })) ); @@ -1123,28 +1126,28 @@ export class PolicyProxyHandler implements Pr ); } - this.utils.tryReject(this.prisma, this.model, 'create'); - this.utils.tryReject(this.prisma, this.model, 'update'); + this.policyUtils.tryReject(this.prisma, this.model, 'create'); + this.policyUtils.tryReject(this.prisma, this.model, 'update'); - args = this.utils.clone(args); + args = this.policyUtils.clone(args); // We can call the native "upsert" because we can't tell if an entity was created or updated // for doing post-write check accordingly. Instead, decompose it into create or update. - const { result, error } = await this.transaction(async (tx) => { + const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { const { where, create, update, ...rest } = args; - const existing = await this.utils.checkExistence(tx, this.model, args.where); + const existing = await this.policyUtils.checkExistence(tx, this.model, args.where); if (existing) { // update case const { result, postWriteChecks } = await this.doUpdate({ where, data: update, ...rest }, tx); await this.runPostWriteChecks(postWriteChecks, tx); - return this.utils.readBack(tx, this.model, 'update', args, result); + return this.policyUtils.readBack(tx, this.model, 'update', args, result); } else { // create case const { result, postWriteChecks } = await this.doCreate(this.model, { data: create, ...rest }, tx); await this.runPostWriteChecks(postWriteChecks, tx); - return this.utils.readBack(tx, this.model, 'create', args, result); + return this.policyUtils.readBack(tx, this.model, 'create', args, result); } }); @@ -1174,19 +1177,19 @@ export class PolicyProxyHandler implements Pr ); } - this.utils.tryReject(this.prisma, this.model, 'delete'); + this.policyUtils.tryReject(this.prisma, this.model, 'delete'); - const { result, error } = await this.transaction(async (tx) => { + const { result, error } = await this.queryUtils.transaction(this.prisma, async (tx) => { // do a read-back before delete - const r = await this.utils.readBack(tx, this.model, 'delete', args, args.where); + const r = await this.policyUtils.readBack(tx, this.model, 'delete', args, args.where); const error = r.error; const read = r.result; // check existence - await this.utils.checkExistence(tx, this.model, args.where, true); + await this.policyUtils.checkExistence(tx, this.model, args.where, true); // inject delete guard - await this.utils.checkPolicyForUnique(this.model, args.where, 'delete', tx, args); + await this.policyUtils.checkPolicyForUnique(this.model, args.where, 'delete', tx, args); // proceed with the deletion if (this.shouldLogQuery) { @@ -1205,11 +1208,11 @@ export class PolicyProxyHandler implements Pr } async deleteMany(args: any) { - this.utils.tryReject(this.prisma, this.model, 'delete'); + this.policyUtils.tryReject(this.prisma, this.model, 'delete'); // inject policy conditions args = args ?? {}; - this.utils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'delete'); + this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'delete'); // conduct the deletion if (this.shouldLogQuery) { @@ -1227,10 +1230,10 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } - args = this.utils.clone(args); + args = this.policyUtils.clone(args); // inject policy conditions - this.utils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); + this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`aggregate\` ${this.model}:\n${formatObject(args)}`); @@ -1243,10 +1246,10 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required'); } - args = this.utils.clone(args); + args = this.policyUtils.clone(args); // inject policy conditions - this.utils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); + this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`groupBy\` ${this.model}:\n${formatObject(args)}`); @@ -1256,8 +1259,8 @@ export class PolicyProxyHandler implements Pr async count(args: any) { // inject policy conditions - args = args ? this.utils.clone(args) : {}; - this.utils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); + args = args ? this.policyUtils.clone(args) : {}; + this.policyUtils.injectAuthGuardAsWhere(this.prisma, args, this.model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`count\` ${this.model}:\n${formatObject(args)}`); @@ -1270,8 +1273,8 @@ export class PolicyProxyHandler implements Pr //#region Subscribe (Prisma Pulse) async subscribe(args: any) { - const readGuard = this.utils.getAuthGuard(this.prisma, this.model, 'read'); - if (this.utils.isTrue(readGuard)) { + const readGuard = this.policyUtils.getAuthGuard(this.prisma, this.model, 'read'); + if (this.policyUtils.isTrue(readGuard)) { // no need to inject if (this.shouldLogQuery) { this.logger.info(`[policy] \`subscribe\` ${this.model}:\n${formatObject(args)}`); @@ -1290,22 +1293,22 @@ export class PolicyProxyHandler implements Pr // include all args = { create: {}, update: {}, delete: {} }; } else { - args = this.utils.clone(args); + args = this.policyUtils.clone(args); } } // inject into subscribe conditions if (args.create) { - args.create.after = this.utils.and(args.create.after, readGuard); + args.create.after = this.policyUtils.and(args.create.after, readGuard); } if (args.update) { - args.update.after = this.utils.and(args.update.after, readGuard); + args.update.after = this.policyUtils.and(args.update.after, readGuard); } if (args.delete) { - args.delete.before = this.utils.and(args.delete.before, readGuard); + args.delete.before = this.policyUtils.and(args.delete.before, readGuard); } if (this.shouldLogQuery) { @@ -1322,23 +1325,10 @@ export class PolicyProxyHandler implements Pr return !!this.options?.logPrismaQuery && this.logger.enabled('info'); } - private transaction(action: (tx: Record) => Promise) { - if (this.prisma['$transaction']) { - return this.prisma.$transaction((tx) => action(tx), { - maxWait: this.options.transactionMaxWait, - timeout: this.options.transactionTimeout, - isolationLevel: this.options.transactionIsolationLevel, - }); - } else { - // already in transaction, don't nest - return action(this.prisma); - } - } - - private async runPostWriteChecks(postWriteChecks: PostWriteCheckRecord[], db: Record) { + private async runPostWriteChecks(postWriteChecks: PostWriteCheckRecord[], db: CrudContract) { await Promise.all( postWriteChecks.map(async ({ model, operation, uniqueFilter, preValue }) => - this.utils.checkPolicyForUnique(model, uniqueFilter, operation, db, undefined, preValue) + this.policyUtils.checkPolicyForUnique(model, uniqueFilter, operation, db, undefined, preValue) ) ); } diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index b7e3448c8..50ef3a3bc 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -16,33 +16,19 @@ import { PRE_UPDATE_VALUE_SELECTOR, PrismaErrorCode, } from '../../constants'; -import { - enumerate, - getFields, - getIdFields, - getModelFields, - resolveField, - zip, - type FieldInfo, - type ModelMeta, - type NestedWriteVisitorContext, -} from '../../cross'; -import { AuthUser, DbClientContract, DbOperations, PolicyOperationKind } from '../../types'; +import { enumerate, getFields, getModelFields, resolveField, zip, type FieldInfo, type ModelMeta } from '../../cross'; +import { AuthUser, CrudContract, DbClientContract, PolicyOperationKind } from '../../types'; import { getVersion } from '../../version'; import type { EnhancementContext, EnhancementOptions } from '../create-enhancement'; +import { Logger } from '../logger'; +import { QueryUtils } from '../query-utils'; import type { InputCheckFunc, PolicyDef, ReadFieldCheckFunc, ZodSchemas } from '../types'; -import { - formatObject, - prismaClientKnownRequestError, - prismaClientUnknownRequestError, - prismaClientValidationError, -} from '../utils'; -import { Logger } from './logger'; +import { formatObject, prismaClientKnownRequestError } from '../utils'; /** * Access policy enforcement utilities */ -export class PolicyUtil { +export class PolicyUtil extends QueryUtils { private readonly logger: Logger; private readonly modelMeta: ModelMeta; private readonly policy: PolicyDef; @@ -56,6 +42,8 @@ export class PolicyUtil { context?: EnhancementContext, private readonly shouldLogQuery = false ) { + super(db, options); + this.logger = new Logger(db); this.user = context?.user; @@ -248,7 +236,7 @@ export class PolicyUtil { * @returns true if operation is unconditionally allowed, false if unconditionally denied, * otherwise returns a guard object */ - getAuthGuard(db: Record, model: string, operation: PolicyOperationKind, preValue?: any) { + getAuthGuard(db: CrudContract, model: string, operation: PolicyOperationKind, preValue?: any) { const guard = this.policy.guard[lowerCaseFirst(model)]; if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); @@ -269,7 +257,7 @@ export class PolicyUtil { /** * Get field-level read auth guard that overrides the model-level */ - getFieldOverrideReadAuthGuard(db: Record, model: string, field: string) { + getFieldOverrideReadAuthGuard(db: CrudContract, model: string, field: string) { const guard = this.requireGuard(model); const provider = guard[`${FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX}${field}`]; @@ -289,7 +277,7 @@ export class PolicyUtil { /** * Get field-level update auth guard */ - getFieldUpdateAuthGuard(db: Record, model: string, field: string) { + getFieldUpdateAuthGuard(db: CrudContract, model: string, field: string) { const guard = this.requireGuard(model); const provider = guard[`${FIELD_LEVEL_UPDATE_GUARD_PREFIX}${field}`]; @@ -309,7 +297,7 @@ export class PolicyUtil { /** * Get field-level update auth guard that overrides the model-level */ - getFieldOverrideUpdateAuthGuard(db: Record, model: string, field: string) { + getFieldOverrideUpdateAuthGuard(db: CrudContract, model: string, field: string) { const guard = this.requireGuard(model); const provider = guard[`${FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX}${field}`]; @@ -365,7 +353,7 @@ export class PolicyUtil { /** * Injects model auth guard as where clause. */ - injectAuthGuardAsWhere(db: Record, args: any, model: string, operation: PolicyOperationKind) { + injectAuthGuardAsWhere(db: CrudContract, args: any, model: string, operation: PolicyOperationKind) { let guard = this.getAuthGuard(db, model, operation); if (operation === 'update' && args) { @@ -413,7 +401,7 @@ export class PolicyUtil { } private injectGuardForRelationFields( - db: Record, + db: CrudContract, model: string, payload: any, operation: PolicyOperationKind @@ -437,7 +425,7 @@ export class PolicyUtil { } private injectGuardForToManyField( - db: Record, + db: CrudContract, fieldInfo: FieldInfo, payload: { some?: any; every?: any; none?: any }, operation: PolicyOperationKind @@ -471,7 +459,7 @@ export class PolicyUtil { } private injectGuardForToOneField( - db: Record, + db: CrudContract, fieldInfo: FieldInfo, payload: { is?: any; isNot?: any } & Record, operation: PolicyOperationKind @@ -501,7 +489,7 @@ export class PolicyUtil { /** * Injects auth guard for read operations. */ - injectForRead(db: Record, model: string, args: any) { + injectForRead(db: CrudContract, model: string, args: any) { // make select and include visible to the injection const injected: any = { select: args.select, include: args.include }; if (!this.injectAuthGuardAsWhere(db, injected, model, 'read')) { @@ -539,111 +527,14 @@ export class PolicyUtil { return true; } - // flatten unique constraint filters - private flattenGeneratedUniqueField(model: string, args: any) { - // e.g.: { a_b: { a: '1', b: '1' } } => { a: '1', b: '1' } - const uniqueConstraints = this.modelMeta.uniqueConstraints?.[lowerCaseFirst(model)]; - if (uniqueConstraints && Object.keys(uniqueConstraints).length > 0) { - for (const [field, value] of Object.entries(args)) { - if ( - uniqueConstraints[field] && - uniqueConstraints[field].fields.length > 1 && - typeof value === 'object' - ) { - // multi-field unique constraint, flatten it - delete args[field]; - if (value) { - for (const [f, v] of Object.entries(value)) { - args[f] = v; - } - } - } - } - } - } - /** * Gets unique constraints for the given model. */ getUniqueConstraints(model: string) { - return this.modelMeta.uniqueConstraints?.[lowerCaseFirst(model)] ?? {}; - } - - /** - * Builds a reversed query for the given nested path. - */ - buildReversedQuery(context: NestedWriteVisitorContext, forMutationPayload = false, unsafeOperation = false) { - let result, currQuery: any; - let currField: FieldInfo | undefined; - - for (let i = context.nestingPath.length - 1; i >= 0; i--) { - const { field, model, where } = context.nestingPath[i]; - - // never modify the original where because it's shared in the structure - const visitWhere = { ...where }; - if (model && where) { - // make sure composite unique condition is flattened - this.flattenGeneratedUniqueField(model, visitWhere); - } - - if (!result) { - // first segment (bottom), just use its where clause - result = currQuery = { ...visitWhere }; - currField = field; - } else { - if (!currField) { - throw this.unknownError(`missing field in nested path`); - } - if (!currField.backLink) { - throw this.unknownError(`field ${currField.type}.${currField.name} doesn't have a backLink`); - } - - const backLinkField = this.getModelField(currField.type, currField.backLink); - if (!backLinkField) { - throw this.unknownError(`missing backLink field ${currField.backLink} in ${currField.type}`); - } - - if (backLinkField.isArray && !forMutationPayload) { - // many-side of relationship, wrap with "some" query - currQuery[currField.backLink] = { some: { ...visitWhere } }; - currQuery = currQuery[currField.backLink].some; - } else { - const fkMapping = where && backLinkField.isRelationOwner && backLinkField.foreignKeyMapping; - - // calculate if we should preserve the relation condition (e.g., { user: { id: 1 } }) - const shouldPreserveRelationCondition = - // doing a mutation - forMutationPayload && - // and it's a safe mutate - !unsafeOperation && - // and the current segment is the direct parent (the last one is the mutate itself), - // the relation condition should be preserved and will be converted to a "connect" later - i === context.nestingPath.length - 2; - - if (fkMapping && !shouldPreserveRelationCondition) { - // turn relation condition into foreign key condition, e.g.: - // { user: { id: 1 } } => { userId: 1 } - for (const [r, fk] of Object.entries(fkMapping)) { - currQuery[fk] = visitWhere[r]; - } - - if (i > 0) { - // prepare for the next segment - currQuery[currField.backLink] = {}; - } - } else { - // preserve the original structure - currQuery[currField.backLink] = { ...visitWhere }; - } - currQuery = currQuery[currField.backLink]; - } - currField = field; - } - } - return result; + return this.modelMeta.models[lowerCaseFirst(model)]?.uniqueConstraints ?? {}; } - private injectNestedReadConditions(db: Record, model: string, args: any): any[] { + private injectNestedReadConditions(db: CrudContract, model: string, args: any): any[] { const injectTarget = args.select ?? args.include; if (!injectTarget) { return []; @@ -736,7 +627,7 @@ export class PolicyUtil { model: string, uniqueFilter: any, operation: PolicyOperationKind, - db: Record, + db: CrudContract, args: any, preValue?: any ) { @@ -830,7 +721,7 @@ export class PolicyUtil { } } - private getFieldReadGuards(db: Record, model: string, args: { select?: any; include?: any }) { + private getFieldReadGuards(db: CrudContract, model: string, args: { select?: any; include?: any }) { const allFields = Object.values(getFields(this.modelMeta, model)); // all scalar fields by default @@ -853,7 +744,7 @@ export class PolicyUtil { return this.and(...allFieldGuards); } - private getFieldUpdateGuards(db: Record, model: string, args: any) { + private getFieldUpdateGuards(db: CrudContract, model: string, args: any) { const allFieldGuards = []; const allOverrideFieldGuards = []; @@ -912,7 +803,7 @@ export class PolicyUtil { /** * Tries rejecting a request based on static "false" policy. */ - tryReject(db: Record, model: string, operation: PolicyOperationKind) { + tryReject(db: CrudContract, model: string, operation: PolicyOperationKind) { const guard = this.getAuthGuard(db, model, operation); if (this.isFalse(guard)) { throw this.deniedByPolicy(model, operation, undefined, CrudFailureReason.ACCESS_POLICY_VIOLATION); @@ -922,12 +813,7 @@ export class PolicyUtil { /** * Checks if a model exists given a unique filter. */ - async checkExistence( - db: Record, - model: string, - uniqueFilter: any, - throwIfNotFound = false - ): Promise { + async checkExistence(db: CrudContract, model: string, uniqueFilter: any, throwIfNotFound = false): Promise { uniqueFilter = this.clone(uniqueFilter); this.flattenGeneratedUniqueField(model, uniqueFilter); @@ -948,7 +834,7 @@ export class PolicyUtil { * Returns an entity given a unique filter with read policy checked. Reject if not readable. */ async readBack( - db: Record, + db: CrudContract, model: string, operation: PolicyOperationKind, selectInclude: { select?: any; include?: any }, @@ -1059,7 +945,7 @@ export class PolicyUtil { } private makeAllScalarFieldSelect(model: string): any { - const fields = this.modelMeta.fields[lowerCaseFirst(model)]; + const fields = this.getModelFields(model); const result: any = {}; if (fields) { Object.entries(fields).forEach(([k, v]) => { @@ -1106,16 +992,6 @@ export class PolicyUtil { }); } - validationError(message: string) { - return prismaClientValidationError(this.db, this.prismaModule, message); - } - - unknownError(message: string) { - return prismaClientUnknownRequestError(this.db, this.prismaModule, message, { - clientVersion: getVersion(), - }); - } - //#endregion //#region Misc @@ -1264,22 +1140,6 @@ export class PolicyUtil { } } - /** - * Gets information for all fields of a model. - */ - getModelFields(model: string) { - model = lowerCaseFirst(model); - return this.modelMeta.fields[model]; - } - - /** - * Gets information for a specific model field. - */ - getModelField(model: string, field: string) { - model = lowerCaseFirst(model); - return this.modelMeta.fields[model]?.[field]; - } - /** * Clones an object and makes sure it's not empty. */ @@ -1300,33 +1160,6 @@ export class PolicyUtil { }, {} as any); } - /** - * Gets "id" fields for a given model. - */ - getIdFields(model: string) { - return getIdFields(this.modelMeta, model, true); - } - - /** - * Gets id field values from an entity. - */ - getEntityIds(model: string, entityData: any) { - const idFields = this.getIdFields(model); - const result: Record = {}; - for (const idField of idFields) { - result[idField.name] = entityData[idField.name]; - } - return result; - } - - /** - * Creates a selection object for id fields for the given model. - */ - makeIdSelection(model: string) { - const idFields = this.getIdFields(model); - return Object.assign({}, ...idFields.map((f) => ({ [f.name]: true }))); - } - private mergeWhereClause(where: any, extra: any) { if (!where) { throw new Error('invalid where clause'); diff --git a/packages/runtime/src/enhancements/proxy.ts b/packages/runtime/src/enhancements/proxy.ts index c735d595a..e0302f7e9 100644 --- a/packages/runtime/src/enhancements/proxy.ts +++ b/packages/runtime/src/enhancements/proxy.ts @@ -3,6 +3,7 @@ import { PRISMA_PROXY_ENHANCER } from '../constants'; import type { ModelMeta } from '../cross'; import type { DbClientContract } from '../types'; +import { EnhancementOptions } from './create-enhancement'; import { createDeferredPromise } from './policy/promise'; /** @@ -31,7 +32,7 @@ export interface PrismaProxyHandler { create(args: any): Promise; - createMany(args: any, skipDuplicates?: boolean): Promise; + createMany(args: { data: any; skipDuplicates?: boolean }): Promise; update(args: any): Promise; @@ -63,7 +64,11 @@ export type PrismaProxyActions = keyof PrismaProxyHandler; * methods to allow more easily inject custom logic. */ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { - constructor(protected readonly prisma: DbClientContract, protected readonly model: string) {} + constructor( + protected readonly prisma: DbClientContract, + protected readonly model: string, + protected readonly options: EnhancementOptions + ) {} async findUnique(args: any): Promise { args = await this.preprocessArgs('findUnique', args); @@ -101,9 +106,9 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { return this.processResultEntity(r); } - async createMany(args: any, skipDuplicates?: boolean | undefined): Promise<{ count: number }> { + async createMany(args: { data: any; skipDuplicates?: boolean }): Promise<{ count: number }> { args = await this.preprocessArgs('createMany', args); - return this.prisma[this.model].createMany(args, skipDuplicates); + return this.prisma[this.model].createMany(args); } async update(args: any): Promise { @@ -182,7 +187,7 @@ export function makeProxy( name = 'unnamed_enhancer', errorTransformer?: ErrorTransformer ) { - const models = Object.keys(modelMeta.fields).map((k) => k.toLowerCase()); + const models = Object.keys(modelMeta.models).map((k) => k.toLowerCase()); const proxy = new Proxy(prisma, { get: (target: any, prop: string | symbol, receiver: any) => { diff --git a/packages/runtime/src/enhancements/query-utils.ts b/packages/runtime/src/enhancements/query-utils.ts new file mode 100644 index 000000000..f92353081 --- /dev/null +++ b/packages/runtime/src/enhancements/query-utils.ts @@ -0,0 +1,172 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { + FieldInfo, + NestedWriteVisitorContext, + getIdFields, + getModelInfo, + getUniqueConstraints, + resolveField, +} from '../cross'; +import { CrudContract, DbClientContract } from '../types'; +import { getVersion } from '../version'; +import { EnhancementOptions } from './create-enhancement'; +import { prismaClientUnknownRequestError, prismaClientValidationError } from './utils'; + +export class QueryUtils { + constructor(private readonly prisma: DbClientContract, private readonly options: EnhancementOptions) {} + + getIdFields(model: string) { + return getIdFields(this.options.modelMeta, model, true); + } + + makeIdSelection(model: string) { + const idFields = this.getIdFields(model); + return Object.assign({}, ...idFields.map((f) => ({ [f.name]: true }))); + } + + getEntityIds(model: string, entityData: any) { + const idFields = this.getIdFields(model); + const result: Record = {}; + for (const idField of idFields) { + result[idField.name] = entityData[idField.name]; + } + return result; + } + + /** + * Initiates a transaction. + */ + transaction(db: CrudContract, action: (tx: CrudContract) => Promise) { + const fullDb = db as DbClientContract; + if (fullDb['$transaction']) { + return fullDb.$transaction( + (tx) => { + (tx as any)[Symbol.for('nodejs.util.inspect.custom')] = 'PrismaClient$tx'; + return action(tx); + }, + { + maxWait: this.options.transactionMaxWait, + timeout: this.options.transactionTimeout, + isolationLevel: this.options.transactionIsolationLevel, + } + ); + } else { + // already in transaction, don't nest + return action(db); + } + } + + buildReversedQuery(context: NestedWriteVisitorContext, mutating = false, unsafeOperation = false) { + let result, currQuery: any; + let currField: FieldInfo | undefined; + + for (let i = context.nestingPath.length - 1; i >= 0; i--) { + const { field, model, where } = context.nestingPath[i]; + + // never modify the original where because it's shared in the structure + const visitWhere = { ...where }; + if (model && where) { + // make sure composite unique condition is flattened + this.flattenGeneratedUniqueField(model, visitWhere); + } + + if (!result) { + // first segment (bottom), just use its where clause + result = currQuery = { ...visitWhere }; + currField = field; + } else { + if (!currField) { + throw this.unknownError(`missing field in nested path`); + } + if (!currField.backLink) { + throw this.unknownError(`field ${currField.type}.${currField.name} doesn't have a backLink`); + } + + const backLinkField = this.getModelField(currField.type, currField.backLink); + if (!backLinkField) { + throw this.unknownError(`missing backLink field ${currField.backLink} in ${currField.type}`); + } + + if (backLinkField.isArray && !mutating) { + // many-side of relationship, wrap with "some" query + currQuery[currField.backLink] = { some: { ...visitWhere } }; + currQuery = currQuery[currField.backLink].some; + } else { + const fkMapping = where && backLinkField.isRelationOwner && backLinkField.foreignKeyMapping; + + // calculate if we should preserve the relation condition (e.g., { user: { id: 1 } }) + const shouldPreserveRelationCondition = + // doing a mutation + mutating && + // and it's a safe mutate + !unsafeOperation && + // and the current segment is the direct parent (the last one is the mutate itself), + // the relation condition should be preserved and will be converted to a "connect" later + i === context.nestingPath.length - 2; + + if (fkMapping && !shouldPreserveRelationCondition) { + // turn relation condition into foreign key condition, e.g.: + // { user: { id: 1 } } => { userId: 1 } + for (const [r, fk] of Object.entries(fkMapping)) { + currQuery[fk] = visitWhere[r]; + } + + if (i > 0) { + // prepare for the next segment + currQuery[currField.backLink] = {}; + } + } else { + // preserve the original structure + currQuery[currField.backLink] = { ...visitWhere }; + } + currQuery = currQuery[currField.backLink]; + } + currField = field; + } + } + return result; + } + + flattenGeneratedUniqueField(model: string, args: any) { + // e.g.: { a_b: { a: '1', b: '1' } } => { a: '1', b: '1' } + const uniqueConstraints = getUniqueConstraints(this.options.modelMeta, model); + if (uniqueConstraints && Object.keys(uniqueConstraints).length > 0) { + for (const [field, value] of Object.entries(args)) { + if ( + uniqueConstraints[field] && + uniqueConstraints[field].fields.length > 1 && + typeof value === 'object' + ) { + // multi-field unique constraint, flatten it + delete args[field]; + if (value) { + for (const [f, v] of Object.entries(value)) { + args[f] = v; + } + } + } + } + } + } + + validationError(message: string) { + return prismaClientValidationError(this.prisma, this.options.prismaModule, message); + } + + unknownError(message: string) { + return prismaClientUnknownRequestError(this.prisma, this.options.prismaModule, message, { + clientVersion: getVersion(), + }); + } + + getModelFields(model: string) { + return getModelInfo(this.options.modelMeta, model)?.fields; + } + + /** + * Gets information for a specific model field. + */ + getModelField(model: string, field: string) { + return resolveField(this.options.modelMeta, model, field); + } +} diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index 4dcfa1c1a..53410a196 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -9,7 +9,7 @@ import { HAS_FIELD_LEVEL_POLICY_FLAG, PRE_UPDATE_VALUE_SELECTOR, } from '../constants'; -import type { DbOperations, PolicyOperationKind, QueryContext } from '../types'; +import type { CrudContract, PolicyOperationKind, QueryContext } from '../types'; /** * Common options for PrismaClient enhancements @@ -24,7 +24,7 @@ export interface CommonEnhancementOptions { /** * Function for getting policy guard with a given context */ -export type PolicyFunc = (context: QueryContext, db: Record) => object; +export type PolicyFunc = (context: QueryContext, db: CrudContract) => object; /** * Function for getting policy guard with a given context diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index e143cacfa..4bcab85a1 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -56,6 +56,14 @@ export type QueryContext = { preValue?: any; }; -export type DbClientContract = Record & { - $transaction: (action: (tx: Record) => Promise, options?: unknown) => Promise; +/** + * Prisma contract for CRUD operations. + */ +export type CrudContract = Record; + +/** + * Prisma contract for database client. + */ +export type DbClientContract = CrudContract & { + $transaction: (action: (tx: CrudContract) => Promise, options?: unknown) => Promise; }; diff --git a/packages/schema/src/cli/cli-util.ts b/packages/schema/src/cli/cli-util.ts index 49923284a..3a92d393c 100644 --- a/packages/schema/src/cli/cli-util.ts +++ b/packages/schema/src/cli/cli-util.ts @@ -85,12 +85,19 @@ export async function loadDocument(fileName: string): Promise { const model = document.parseResult.value as Model; - mergeImportsDeclarations(langiumDocuments, model); + const imported = mergeImportsDeclarations(langiumDocuments, model); + // remove imported documents + await services.shared.workspace.DocumentBuilder.update( + [], + imported.map((m) => m.$document!.uri) + ); validationAfterMerge(model); mergeBaseModel(model, services.references.Linker); + await relinkAll(model, services); + return model; } @@ -151,6 +158,8 @@ export function mergeImportsDeclarations(documents: LangiumDocuments, model: Mod }); model.declarations.push(...importedDeclarations); + + return importedModels; } export async function getPluginDocuments(services: ZModelServices, fileName: string): Promise { @@ -295,3 +304,20 @@ export function getDefaultSchemaLocation() { return path.resolve('schema.zmodel'); } + +async function relinkAll(model: Model, services: ZModelServices) { + const doc = model.$document!; + + // unlink the document + services.references.Linker.unlink(doc); + + // remove current document + await services.shared.workspace.DocumentBuilder.update([], [doc.uri]); + + // recreate the document + const newDoc = services.shared.workspace.LangiumDocumentFactory.fromModel(model, doc.uri); + (model as Mutable).$document = newDoc; + + // rebuild the document + await services.shared.workspace.DocumentBuilder.build([newDoc], { validationChecks: 'all' }); +} diff --git a/packages/schema/src/extension.ts b/packages/schema/src/extension.ts index d28f7dd87..a3e19d7f8 100644 --- a/packages/schema/src/extension.ts +++ b/packages/schema/src/extension.ts @@ -56,6 +56,6 @@ function startLanguageClient(context: vscode.ExtensionContext): LanguageClient { const client = new LanguageClient('zmodel', 'ZenStack Model', serverOptions, clientOptions); // Start the client. This will also launch the server - client.start(); + void client.start(); return client; } diff --git a/packages/schema/src/language-server/validator/datamodel-validator.ts b/packages/schema/src/language-server/validator/datamodel-validator.ts index 33ec0ff37..3e4517444 100644 --- a/packages/schema/src/language-server/validator/datamodel-validator.ts +++ b/packages/schema/src/language-server/validator/datamodel-validator.ts @@ -6,14 +6,9 @@ import { isStringLiteral, ReferenceExpr, } from '@zenstackhq/language/ast'; -import { - analyzePolicies, - getLiteral, - getModelFieldsWithBases, - getModelIdFields, - getModelUniqueFields, -} from '@zenstackhq/sdk'; +import { analyzePolicies, getLiteral, getModelIdFields, getModelUniqueFields, isDelegateModel } from '@zenstackhq/sdk'; import { AstNode, DiagnosticInfo, getDocument, ValidationAcceptor } from 'langium'; +import { getModelFieldsWithBases } from '../../utils/ast-utils'; import { IssueCodes, SCALAR_TYPES } from '../constants'; import { AstValidator } from '../types'; import { getUniqueFields } from '../utils'; @@ -26,7 +21,7 @@ import { validateDuplicatedDeclarations } from './utils'; export default class DataModelValidator implements AstValidator { validate(dm: DataModel, accept: ValidationAcceptor): void { this.validateBaseAbstractModel(dm, accept); - validateDuplicatedDeclarations(getModelFieldsWithBases(dm), accept); + validateDuplicatedDeclarations(dm, getModelFieldsWithBases(dm), accept); this.validateAttributes(dm, accept); this.validateFields(dm, accept); } @@ -224,6 +219,11 @@ export default class DataModelValidator implements AstValidator { return; } + if (field.$container !== contextModel && isDelegateModel(field.$container as DataModel)) { + // relation fields inherited from delegate model don't need opposite relation + return; + } + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion const oppositeModel = field.type.reference!.ref! as DataModel; @@ -265,7 +265,7 @@ export default class DataModelValidator implements AstValidator { return; } else if (oppositeFields.length > 1) { oppositeFields - .filter((x) => !x.$inheritedFrom) + .filter((f) => f.$container !== contextModel) .forEach((f) => { if (this.isSelfRelation(f)) { // self relations are partial @@ -368,12 +368,19 @@ export default class DataModelValidator implements AstValidator { private validateBaseAbstractModel(model: DataModel, accept: ValidationAcceptor) { model.superTypes.forEach((superType, index) => { - if (!superType.ref?.isAbstract) - accept('error', `Model ${superType.$refText} cannot be extended because it's not abstract`, { - node: model, - property: 'superTypes', - index, - }); + if ( + !superType.ref?.isAbstract && + !superType.ref?.attributes.some((attr) => attr.decl.ref?.name === '@@delegate') + ) + accept( + 'error', + `Model ${superType.$refText} cannot be extended because it's neither abstract nor marked as "@@delegate"`, + { + node: model, + property: 'superTypes', + index, + } + ); }); } } diff --git a/packages/schema/src/language-server/validator/datasource-validator.ts b/packages/schema/src/language-server/validator/datasource-validator.ts index f24fed08b..d102e409f 100644 --- a/packages/schema/src/language-server/validator/datasource-validator.ts +++ b/packages/schema/src/language-server/validator/datasource-validator.ts @@ -9,7 +9,7 @@ import { SUPPORTED_PROVIDERS } from '../constants'; */ export default class DataSourceValidator implements AstValidator { validate(ds: DataSource, accept: ValidationAcceptor): void { - validateDuplicatedDeclarations(ds.fields, accept); + validateDuplicatedDeclarations(ds, ds.fields, accept); this.validateProvider(ds, accept); this.validateUrl(ds, accept); this.validateRelationMode(ds, accept); diff --git a/packages/schema/src/language-server/validator/enum-validator.ts b/packages/schema/src/language-server/validator/enum-validator.ts index 4223d8a2b..5780d91fb 100644 --- a/packages/schema/src/language-server/validator/enum-validator.ts +++ b/packages/schema/src/language-server/validator/enum-validator.ts @@ -10,7 +10,7 @@ import { validateDuplicatedDeclarations } from './utils'; export default class EnumValidator implements AstValidator { // eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types validate(_enum: Enum, accept: ValidationAcceptor) { - validateDuplicatedDeclarations(_enum.fields, accept); + validateDuplicatedDeclarations(_enum, _enum.fields, accept); this.validateAttributes(_enum, accept); _enum.fields.forEach((field) => { this.validateField(field, accept); diff --git a/packages/schema/src/language-server/validator/schema-validator.ts b/packages/schema/src/language-server/validator/schema-validator.ts index b80bf890d..d3722638e 100644 --- a/packages/schema/src/language-server/validator/schema-validator.ts +++ b/packages/schema/src/language-server/validator/schema-validator.ts @@ -13,7 +13,7 @@ export default class SchemaValidator implements AstValidator { constructor(protected readonly documents: LangiumDocuments) {} validate(model: Model, accept: ValidationAcceptor): void { this.validateImports(model, accept); - validateDuplicatedDeclarations(model.declarations, accept); + validateDuplicatedDeclarations(model, model.declarations, accept); const importedModels = resolveTransitiveImports(this.documents, model); diff --git a/packages/schema/src/language-server/validator/utils.ts b/packages/schema/src/language-server/validator/utils.ts index 340f471b8..6a1a44336 100644 --- a/packages/schema/src/language-server/validator/utils.ts +++ b/packages/schema/src/language-server/validator/utils.ts @@ -3,7 +3,6 @@ import { AttributeParam, BuiltinType, DataModelAttribute, - DataModelField, DataModelFieldAttribute, Expression, ExpressionType, @@ -21,6 +20,7 @@ import { AstNode, ValidationAcceptor } from 'langium'; * Checks if the given declarations have duplicated names */ export function validateDuplicatedDeclarations( + container: AstNode, decls: Array, accept: ValidationAcceptor ): void { @@ -34,7 +34,7 @@ export function validateDuplicatedDeclarations( if (decls.length > 1) { let errorField = decls[1]; if (isDataModelField(decls[0])) { - const nonInheritedFields = decls.filter((x) => !(x as DataModelField).$inheritedFrom); + const nonInheritedFields = decls.filter((x) => !(isDataModelField(x) && x.$container !== container)); if (nonInheritedFields.length > 0) { errorField = nonInheritedFields.slice(-1)[0]; } diff --git a/packages/schema/src/language-server/zmodel-code-action.ts b/packages/schema/src/language-server/zmodel-code-action.ts index 8f60cbe69..5b6a6c95a 100644 --- a/packages/schema/src/language-server/zmodel-code-action.ts +++ b/packages/schema/src/language-server/zmodel-code-action.ts @@ -10,8 +10,8 @@ import { getDocument, } from 'langium'; -import { getModelFieldsWithBases } from '@zenstackhq/sdk'; import { CodeAction, CodeActionKind, CodeActionParams, Command, Diagnostic } from 'vscode-languageserver'; +import { getModelFieldsWithBases } from '../utils/ast-utils'; import { IssueCodes } from './constants'; import { MissingOppositeRelationData } from './validator/datamodel-validator'; import { ZModelFormatter } from './zmodel-formatter'; diff --git a/packages/schema/src/language-server/zmodel-completion-provider.ts b/packages/schema/src/language-server/zmodel-completion-provider.ts index 742f7087f..e100c870a 100644 --- a/packages/schema/src/language-server/zmodel-completion-provider.ts +++ b/packages/schema/src/language-server/zmodel-completion-provider.ts @@ -159,7 +159,7 @@ export class ZModelCompletionProvider extends DefaultCompletionProvider { acceptor(item); }; - super.completionForCrossReference(context, crossRef, customAcceptor); + return super.completionForCrossReference(context, crossRef, customAcceptor); } override completionForKeyword( @@ -174,7 +174,7 @@ export class ZModelCompletionProvider extends DefaultCompletionProvider { } acceptor(item); }; - super.completionForKeyword(context, keyword, customAcceptor); + return super.completionForKeyword(context, keyword, customAcceptor); } private filterKeywordForContext(context: CompletionContext, keyword: string) { diff --git a/packages/schema/src/language-server/zmodel-linker.ts b/packages/schema/src/language-server/zmodel-linker.ts index 30929791f..5ab841c96 100644 --- a/packages/schema/src/language-server/zmodel-linker.ts +++ b/packages/schema/src/language-server/zmodel-linker.ts @@ -470,12 +470,12 @@ export class ZModelLinker extends DefaultLinker { } private resolveDataModel(node: DataModel, document: LangiumDocument, extraScopes: ScopeProvider[]) { - if (node.superTypes.length > 0) { - const providers = node.superTypes.map( - (superType) => (name: string) => superType.ref?.fields.find((f) => f.name === name) - ); - extraScopes = [...providers, ...extraScopes]; - } + // if (node.superTypes.length > 0) { + // const providers = node.superTypes.map( + // (superType) => (name: string) => superType.ref?.fields.find((f) => f.name === name) + // ); + // extraScopes = [...providers, ...extraScopes]; + // } return this.resolveDefault(node, document, extraScopes); } diff --git a/packages/schema/src/language-server/zmodel-scope.ts b/packages/schema/src/language-server/zmodel-scope.ts index 21304fa4a..9d685db27 100644 --- a/packages/schema/src/language-server/zmodel-scope.ts +++ b/packages/schema/src/language-server/zmodel-scope.ts @@ -16,7 +16,6 @@ import { getModelFieldsWithBases, getRecursiveBases, isAuthInvocation, - isFutureExpr, } from '@zenstackhq/sdk'; import { AstNode, @@ -38,7 +37,7 @@ import { } from 'langium'; import { match } from 'ts-pattern'; import { CancellationToken } from 'vscode-jsonrpc'; -import { isCollectionPredicate, resolveImportUri } from '../utils/ast-utils'; +import { isCollectionPredicate, isFutureInvocation, resolveImportUri } from '../utils/ast-utils'; import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from './constants'; /** @@ -76,7 +75,7 @@ export class ZModelScopeComputation extends DefaultScopeComputation { override processNode(node: AstNode, document: LangiumDocument, scopes: PrecomputedScopes) { super.processNode(node, document, scopes); - if (isDataModel(node)) { + if (isDataModel(node) && !node.$baseMerged) { // add base fields to the scope recursively const bases = getRecursiveBases(node); for (const base of bases) { @@ -164,7 +163,7 @@ export class ZModelScopeProvider extends DefaultScopeProvider { // resolve to `User` or `@@auth` model return this.createScopeForAuthModel(node, globalScope); } - if (isFutureExpr(operand)) { + if (isFutureInvocation(operand)) { // resolve `future()` to the containing model return this.createScopeForContainingModel(node, globalScope); } diff --git a/packages/schema/src/plugins/enhancer/delegate/index.ts b/packages/schema/src/plugins/enhancer/delegate/index.ts new file mode 100644 index 000000000..5e4cffdfa --- /dev/null +++ b/packages/schema/src/plugins/enhancer/delegate/index.ts @@ -0,0 +1,16 @@ +import { type PluginOptions } from '@zenstackhq/sdk'; +import type { Model } from '@zenstackhq/sdk/ast'; +import type { Project } from 'ts-morph'; +import { PrismaSchemaGenerator } from '../../prisma/schema-generator'; +import path from 'path'; + +export async function generate(model: Model, options: PluginOptions, project: Project, outDir: string) { + const prismaGenerator = new PrismaSchemaGenerator(); + await prismaGenerator.generate(model, { + provider: '@internal', + schemaPath: options.schemaPath, + output: path.join(outDir, 'delegate.prisma'), + overrideClientGenerationPath: path.join(outDir, '.delegate'), + mode: 'logical', + }); +} diff --git a/packages/schema/src/plugins/enhancer/enhance/index.ts b/packages/schema/src/plugins/enhancer/enhance/index.ts new file mode 100644 index 000000000..c33de08b0 --- /dev/null +++ b/packages/schema/src/plugins/enhancer/enhance/index.ts @@ -0,0 +1,215 @@ +import { DELEGATE_AUX_RELATION_PREFIX } from '@zenstackhq/runtime'; +import { + getAttribute, + getDataModels, + getPrismaClientImportSpec, + isDelegateModel, + type PluginOptions, +} from '@zenstackhq/sdk'; +import { DataModelField, isDataModel, isReferenceExpr, type DataModel, type Model } from '@zenstackhq/sdk/ast'; +import path from 'path'; +import { + ForEachDescendantTraversalControl, + MethodSignature, + Node, + Project, + PropertySignature, + SyntaxKind, + TypeAliasDeclaration, +} from 'ts-morph'; +import { PrismaSchemaGenerator } from '../../prisma/schema-generator'; + +export async function generate(model: Model, options: PluginOptions, project: Project, outDir: string) { + const outFile = path.join(outDir, 'enhance.ts'); + let logicalPrismaClientDir: string | undefined; + + if (hasDelegateModel(model)) { + logicalPrismaClientDir = await generateLogicalPrisma(model, options, outDir); + } + + project.createSourceFile( + outFile, + `import { createEnhancement, type EnhancementContext, type EnhancementOptions, type ZodSchemas } from '@zenstackhq/runtime'; +import modelMeta from './model-meta'; +import policy from './policy'; +${options.withZodSchemas ? "import * as zodSchemas from './zod';" : 'const zodSchemas = undefined;'} +import { Prisma } from '${getPrismaClientImportSpec(model, outDir)}'; +${logicalPrismaClientDir ? `import { PrismaClient as EnhancedPrismaClient } from '${logicalPrismaClientDir}';` : ''} + +export function enhance(prisma: DbClient, context?: EnhancementContext, options?: EnhancementOptions) { + return createEnhancement(prisma, { + modelMeta, + policy, + zodSchemas: zodSchemas as unknown as (ZodSchemas | undefined), + prismaModule: Prisma, + ...options + }, context)${logicalPrismaClientDir ? ' as EnhancedPrismaClient' : ''}; +} +`, + { overwrite: true } + ); +} + +function hasDelegateModel(model: Model) { + const dataModels = getDataModels(model); + return dataModels.some( + (dm) => isDelegateModel(dm) && dataModels.some((sub) => sub.superTypes.some((base) => base.ref === dm)) + ); +} + +async function generateLogicalPrisma(model: Model, options: PluginOptions, outDir: string) { + const prismaGenerator = new PrismaSchemaGenerator(); + const prismaClientOutDir = './.delegate'; + await prismaGenerator.generate(model, { + provider: '@internal', + schemaPath: options.schemaPath, + output: path.join(outDir, 'delegate.prisma'), + overrideClientGenerationPath: prismaClientOutDir, + mode: 'logical', + }); + + await processClientTypes(model, path.join(outDir, prismaClientOutDir)); + return prismaClientOutDir; +} + +async function processClientTypes(model: Model, prismaClientDir: string) { + const project = new Project(); + const sf = project.addSourceFileAtPath(path.join(prismaClientDir, 'index.d.ts')); + + const delegateModels: [DataModel, DataModel[]][] = []; + model.declarations + .filter((d): d is DataModel => isDelegateModel(d)) + .forEach((dm) => { + delegateModels.push([ + dm, + model.declarations.filter( + (d): d is DataModel => isDataModel(d) && d.superTypes.some((s) => s.ref === dm) + ), + ]); + }); + + const toRemove: (PropertySignature | MethodSignature)[] = []; + const toReplaceText: [TypeAliasDeclaration, string][] = []; + + sf.forEachDescendant((desc, traversal) => { + removeAuxRelationFields(desc, toRemove, traversal); + fixDelegateUnionType(desc, delegateModels, toReplaceText, traversal); + removeCreateFromDelegateInputTypes(desc, delegateModels, toRemove, traversal); + removeToplevelCreates(desc, delegateModels, toRemove, traversal); + }); + + toRemove.forEach((n) => n.remove()); + toReplaceText.forEach(([node, text]) => node.replaceWithText(text)); + + await project.save(); +} + +function removeAuxRelationFields( + desc: Node, + toRemove: (PropertySignature | MethodSignature)[], + traversal: ForEachDescendantTraversalControl +) { + if (desc.isKind(SyntaxKind.PropertySignature) || desc.isKind(SyntaxKind.MethodSignature)) { + // remove aux fields + const name = desc.getName(); + + if (name.startsWith(DELEGATE_AUX_RELATION_PREFIX)) { + toRemove.push(desc); + traversal.skip(); + } + } +} + +function fixDelegateUnionType( + desc: Node, + delegateModels: [DataModel, DataModel[]][], + toReplaceText: [TypeAliasDeclaration, string][], + traversal: ForEachDescendantTraversalControl +) { + if (!desc.isKind(SyntaxKind.TypeAliasDeclaration)) { + return; + } + + const name = desc.getName(); + delegateModels.forEach(([delegate, concreteModels]) => { + if (name === `$${delegate.name}Payload`) { + const discriminator = getDiscriminatorField(delegate); + // const discriminator = 'delegateType'; // delegate.fields.find((f) => hasAttribute(f, '@discriminator')); + if (discriminator) { + toReplaceText.push([ + desc, + `export type ${name} = + ${concreteModels + .map((m) => `($${m.name}Payload & { scalars: { ${discriminator.name}: '${m.name}' } })`) + .join(' | ')};`, + ]); + traversal.skip(); + } + } + }); +} + +function removeCreateFromDelegateInputTypes( + desc: Node, + delegateModels: [DataModel, DataModel[]][], + toRemove: (PropertySignature | MethodSignature)[], + traversal: ForEachDescendantTraversalControl +) { + if (!desc.isKind(SyntaxKind.TypeAliasDeclaration)) { + return; + } + + const name = desc.getName(); + delegateModels.forEach(([delegate]) => { + // remove create related sub-payload from delegate's input types since they cannot be created directly + const regex = new RegExp(`\\${delegate.name}(Unchecked)?(Create|Update).*Input`); + if (regex.test(name)) { + desc.forEachDescendant((d, innerTraversal) => { + if ( + d.isKind(SyntaxKind.PropertySignature) && + ['create', 'upsert', 'connectOrCreate'].includes(d.getName()) + ) { + toRemove.push(d); + innerTraversal.skip(); + } + }); + traversal.skip(); + } + }); +} + +function removeToplevelCreates( + desc: Node, + delegateModels: [DataModel, DataModel[]][], + toRemove: (PropertySignature | MethodSignature)[], + traversal: ForEachDescendantTraversalControl +) { + if (desc.isKind(SyntaxKind.InterfaceDeclaration)) { + // remove create and upsert methods from delegate interfaces since they cannot be created directly + const name = desc.getName(); + if (delegateModels.map(([dm]) => `${dm.name}Delegate`).includes(name)) { + const createMethod = desc.getMethod('create'); + if (createMethod) { + toRemove.push(createMethod); + } + const createManyMethod = desc.getMethod('createMany'); + if (createManyMethod) { + toRemove.push(createManyMethod); + } + const upsertMethod = desc.getMethod('upsert'); + if (upsertMethod) { + toRemove.push(upsertMethod); + } + traversal.skip(); + } + } +} + +function getDiscriminatorField(delegate: DataModel) { + const delegateAttr = getAttribute(delegate, '@@delegate'); + if (!delegateAttr) { + return undefined; + } + const arg = delegateAttr.args[0]?.value; + return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined; +} diff --git a/packages/schema/src/plugins/enhancer/enhancer.ts b/packages/schema/src/plugins/enhancer/enhancer.ts deleted file mode 100644 index 5eccd356d..000000000 --- a/packages/schema/src/plugins/enhancer/enhancer.ts +++ /dev/null @@ -1,29 +0,0 @@ -import { getPrismaClientImportSpec, type PluginOptions } from '@zenstackhq/sdk'; -import type { Model } from '@zenstackhq/sdk/ast'; -import path from 'path'; -import type { Project } from 'ts-morph'; - -export async function generate(model: Model, options: PluginOptions, project: Project, outDir: string) { - const outFile = path.join(outDir, 'enhance.ts'); - - project.createSourceFile( - outFile, - `import { createEnhancement, type EnhancementContext, type EnhancementOptions, type ZodSchemas } from '@zenstackhq/runtime'; -import modelMeta from './model-meta'; -import policy from './policy'; -${options.withZodSchemas ? "import * as zodSchemas from './zod';" : 'const zodSchemas = undefined;'} -import { Prisma } from '${getPrismaClientImportSpec(model, outDir)}'; - -export function enhance(prisma: DbClient, context?: EnhancementContext, options?: EnhancementOptions): DbClient { - return createEnhancement(prisma, { - modelMeta, - policy, - zodSchemas: zodSchemas as unknown as (ZodSchemas | undefined), - prismaModule: Prisma, - ...options - }, context); -} -`, - { overwrite: true } - ); -} diff --git a/packages/schema/src/plugins/enhancer/index.ts b/packages/schema/src/plugins/enhancer/index.ts index 45f3ceb35..86e3ecf39 100644 --- a/packages/schema/src/plugins/enhancer/index.ts +++ b/packages/schema/src/plugins/enhancer/index.ts @@ -7,7 +7,7 @@ import { type PluginFunction, } from '@zenstackhq/sdk'; import { getDefaultOutputFolder } from '../plugin-utils'; -import { generate as generateEnhancer } from './enhancer'; +import { generate as generateEnhancer } from './enhance'; import { generate as generateModelMeta } from './model-meta'; import { generate as generatePolicy } from './policy'; diff --git a/packages/schema/src/plugins/enhancer/model-meta.ts b/packages/schema/src/plugins/enhancer/model-meta/index.ts similarity index 100% rename from packages/schema/src/plugins/enhancer/model-meta.ts rename to packages/schema/src/plugins/enhancer/model-meta/index.ts diff --git a/packages/schema/src/plugins/enhancer/policy/expression-writer.ts b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts index e38a34c29..9333634fa 100644 --- a/packages/schema/src/plugins/enhancer/policy/expression-writer.ts +++ b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts @@ -2,9 +2,11 @@ import { BinaryExpr, BooleanLiteral, DataModel, + DataModelField, Expression, InvocationExpr, isDataModel, + isDataModelField, isEnumField, isMemberAccessExpr, isReferenceExpr, @@ -13,9 +15,11 @@ import { MemberAccessExpr, NumberLiteral, ReferenceExpr, + ReferenceTarget, StringLiteral, UnaryExpr, } from '@zenstackhq/language/ast'; +import { DELEGATE_AUX_RELATION_PREFIX } from '@zenstackhq/runtime'; import { ExpressionContext, getFunctionExpressionContext, @@ -23,12 +27,14 @@ import { getLiteral, isAuthInvocation, isDataModelFieldReference, + isDelegateModel, isFutureExpr, PluginError, TypeScriptExpressionTransformer, TypeScriptExpressionTransformerError, } from '@zenstackhq/sdk'; import { lowerCaseFirst } from 'lower-case-first'; +import invariant from 'tiny-invariant'; import { CodeBlockWriter } from 'ts-morph'; import { name } from '..'; @@ -113,11 +119,44 @@ export class ExpressionWriter { throw new Error('We should never get here'); } else { this.block(() => { - this.writer.write(`${expr.target.ref?.name}: true`); + const ref = expr.target.ref; + invariant(ref); + if (this.isFieldReferenceToDelegateModel(ref)) { + const thisModel = ref.$container as DataModel; + const targetBase = ref.$inheritedFrom; + this.writeBaseHierarchy(thisModel, targetBase, () => this.writer.write(`${ref.name}: true`)); + } else { + this.writer.write(`${ref.name}: true`); + } }); } } + private writeBaseHierarchy(thisModel: DataModel, targetBase: DataModel | undefined, conditionWriter: () => void) { + if (!targetBase || thisModel === targetBase) { + conditionWriter(); + return; + } + + const base = this.getDelegateBase(thisModel); + if (!base) { + throw new PluginError(name, `Failed to resolve delegate base model for "${thisModel.name}"`); + } + + this.writer.write(`${`${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(base.name)}`}: `); + this.writer.block(() => { + this.writeBaseHierarchy(base, targetBase, conditionWriter); + }); + } + + private getDelegateBase(model: DataModel) { + return model.superTypes.map((t) => t.ref).filter((t) => t && isDelegateModel(t))?.[0]; + } + + private isFieldReferenceToDelegateModel(ref: ReferenceTarget): ref is DataModelField { + return isDataModelField(ref) && !!ref.$inheritedFrom && isDelegateModel(ref.$inheritedFrom); + } + private writeMemberAccess(expr: MemberAccessExpr) { if (this.isAuthOrAuthMemberAccess(expr)) { // member access of `auth()`, generate plain expression @@ -496,48 +535,67 @@ export class ExpressionWriter { filterOp?: FilterOperators, extraArgs?: Record ) { - let selector: string | undefined; + // let selector: string | undefined; let operand: Expression | undefined; + let fieldWriter: ((conditionWriter: () => void) => void) | undefined; if (isThisExpr(fieldAccess)) { // pass on writeCondition(); return; } else if (isReferenceExpr(fieldAccess)) { - selector = fieldAccess.target.ref?.name; + const ref = fieldAccess.target.ref; + invariant(ref); + if (this.isFieldReferenceToDelegateModel(ref)) { + const thisModel = ref.$container as DataModel; + const targetBase = ref.$inheritedFrom; + fieldWriter = (conditionWriter: () => void) => + this.writeBaseHierarchy(thisModel, targetBase, () => { + this.writer.write(`${ref.name}: `); + conditionWriter(); + }); + } else { + fieldWriter = (conditionWriter: () => void) => { + this.writer.write(`${ref.name}: `); + conditionWriter(); + }; + } } else if (isMemberAccessExpr(fieldAccess)) { - if (isFutureExpr(fieldAccess.operand)) { + if (!isFutureExpr(fieldAccess.operand)) { // future().field should be treated as the "field" - selector = fieldAccess.member.ref?.name; - } else { - selector = fieldAccess.member.ref?.name; operand = fieldAccess.operand; } + fieldWriter = (conditionWriter: () => void) => { + this.writer.write(`${fieldAccess.member.ref?.name}: `); + conditionWriter(); + }; } else { throw new PluginError(name, `Unsupported expression type: ${fieldAccess.$type}`); } - if (!selector) { + if (!fieldWriter) { throw new PluginError(name, `Failed to write FieldAccess expression`); } const writerFilterOutput = () => { - this.writer.write(selector + ': '); - if (filterOp) { - this.block(() => { - this.writer.write(`${filterOp}: `); - writeCondition(); + // this.writer.write(selector + ': '); + fieldWriter!(() => { + if (filterOp) { + this.block(() => { + this.writer.write(`${filterOp}: `); + writeCondition(); - if (extraArgs) { - for (const [k, v] of Object.entries(extraArgs)) { - this.writer.write(`,\n${k}: `); - this.plain(v); + if (extraArgs) { + for (const [k, v] of Object.entries(extraArgs)) { + this.writer.write(`,\n${k}: `); + this.plain(v); + } } - } - }); - } else { - writeCondition(); - } + }); + } else { + writeCondition(); + } + }); }; if (operand) { diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index 149858cd6..2032f2b99 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -1,12 +1,9 @@ import { DataModel, - DataModelAttribute, DataModelField, - DataModelFieldAttribute, Enum, Expression, Model, - isBinaryExpr, isDataModel, isDataModelField, isEnum, @@ -15,7 +12,6 @@ import { isMemberAccessExpr, isReferenceExpr, isThisExpr, - isUnaryExpr, } from '@zenstackhq/language/ast'; import { FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX, @@ -71,7 +67,7 @@ export class PolicyGenerator { sf.addImportDeclaration({ namedImports: [ { name: 'type QueryContext' }, - { name: 'type DbOperations' }, + { name: 'type CrudContract' }, { name: 'allFieldsEqual' }, { name: 'type PolicyDef' }, ], @@ -203,7 +199,7 @@ export class PolicyGenerator { operation: PolicyOperationKind, override = false ) { - const attributes = target.attributes as (DataModelAttribute | DataModelFieldAttribute)[]; + const attributes = target.attributes; const attrName = isDataModel(target) ? `@@${kind}` : `@${kind}`; const attrs = attributes.filter((attr) => { if (attr.decl.ref?.name !== attrName) { @@ -253,30 +249,6 @@ export class PolicyGenerator { } } - private visitPolicyExpression(expr: Expression, postUpdate: boolean): Expression | undefined { - if (isBinaryExpr(expr) && (expr.operator === '&&' || expr.operator === '||')) { - const left = this.visitPolicyExpression(expr.left, postUpdate); - const right = this.visitPolicyExpression(expr.right, postUpdate); - if (!left) return right; - if (!right) return left; - return { ...expr, left, right }; - } - - if (isUnaryExpr(expr) && expr.operator === '!') { - const operand = this.visitPolicyExpression(expr.operand, postUpdate); - if (!operand) return undefined; - return { ...expr, operand }; - } - - if (postUpdate && !this.hasFutureReference(expr)) { - return undefined; - } else if (!postUpdate && this.hasFutureReference(expr)) { - return undefined; - } - - return expr; - } - private hasFutureReference(expr: Expression) { for (const node of streamAst(expr)) { if (isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)) { @@ -766,7 +738,7 @@ export class PolicyGenerator { { // for generating field references used by field comparison in the same model name: 'db', - type: 'Record', + type: 'CrudContract', }, ], statements, diff --git a/packages/schema/src/plugins/plugin-utils.ts b/packages/schema/src/plugins/plugin-utils.ts index f4f521fdc..00b806e7e 100644 --- a/packages/schema/src/plugins/plugin-utils.ts +++ b/packages/schema/src/plugins/plugin-utils.ts @@ -35,13 +35,9 @@ export function ensureDefaultOutputFolder(options: PluginRunnerOptions) { name: '.zenstack', version: '1.0.0', exports: { - './model-meta': { - types: './model-meta.ts', - default: './model-meta.js', - }, - './policy': { - types: './policy.d.ts', - default: './policy.js', + './enhance': { + types: './enhance.d.ts', + default: './enhance.js', }, './zod': { types: './zod/index.d.ts', diff --git a/packages/schema/src/plugins/prisma/index.ts b/packages/schema/src/plugins/prisma/index.ts index c4b209aa6..b27624cd7 100644 --- a/packages/schema/src/plugins/prisma/index.ts +++ b/packages/schema/src/plugins/prisma/index.ts @@ -1,5 +1,5 @@ import { PluginFunction } from '@zenstackhq/sdk'; -import PrismaSchemaGenerator from './schema-generator'; +import { PrismaSchemaGenerator } from './schema-generator'; export const name = 'Prisma'; export const description = 'Generating Prisma schema'; diff --git a/packages/schema/src/plugins/prisma/prisma-builder.ts b/packages/schema/src/plugins/prisma/prisma-builder.ts index 64777b62e..b65313940 100644 --- a/packages/schema/src/plugins/prisma/prisma-builder.ts +++ b/packages/schema/src/plugins/prisma/prisma-builder.ts @@ -110,10 +110,15 @@ export class Model extends ContainerDeclaration { name: string, type: ModelFieldType | string, attributes: (FieldAttribute | PassThroughAttribute)[] = [], - documentations: string[] = [] + documentations: string[] = [], + addToFront = false ): ModelField { const field = new ModelField(name, type, attributes, documentations); - this.fields.push(field); + if (addToFront) { + this.fields.unshift(field); + } else { + this.fields.push(field); + } return field; } diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 0be727d31..01a8efc60 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -16,6 +16,7 @@ import { GeneratorDecl, InvocationExpr, isArrayExpr, + isDataModel, isInvocationExpr, isLiteralExpr, isNullExpr, @@ -27,14 +28,17 @@ import { StringLiteral, } from '@zenstackhq/language/ast'; import { match } from 'ts-pattern'; +import { getIdFields } from '../../utils/ast-utils'; -import { PRISMA_MINIMUM_VERSION } from '@zenstackhq/runtime'; +import { DELEGATE_AUX_RELATION_PREFIX, PRISMA_MINIMUM_VERSION } from '@zenstackhq/runtime'; import { getAttribute, getDMMF, getLiteral, getPrismaVersion, isAuthInvocation, + isDelegateModel, + isIdField, PluginError, PluginOptions, resolved, @@ -44,15 +48,18 @@ import { import fs from 'fs'; import { writeFile } from 'fs/promises'; import { streamAst } from 'langium'; +import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; import semver from 'semver'; import stripColor from 'strip-color'; +import { upperCaseFirst } from 'upper-case-first'; import { name } from '.'; import { getStringLiteral } from '../../language-server/validator/utils'; import telemetry from '../../telemetry'; import { execSync } from '../../utils/exec-utils'; import { findPackageJson } from '../../utils/pkg-utils'; import { + AttributeArgValue, ModelFieldType, AttributeArg as PrismaAttributeArg, AttributeArgValue as PrismaAttributeArgValue, @@ -76,7 +83,7 @@ const FIELD_PASSTHROUGH_ATTR = '@prisma.passthrough'; /** * Generates Prisma schema file */ -export default class PrismaSchemaGenerator { +export class PrismaSchemaGenerator { private zModelGenerator: ZModelCodeGenerator = new ZModelCodeGenerator(); private readonly PRELUDE = `////////////////////////////////////////////////////////////////////////////////////////////// @@ -86,8 +93,13 @@ export default class PrismaSchemaGenerator { `; + private mode: 'logical' | 'physical' = 'physical'; + async generate(model: Model, options: PluginOptions) { const warnings: string[] = []; + if (options.mode) { + this.mode = options.mode as 'logical' | 'physical'; + } const prismaVersion = getPrismaVersion(); if (prismaVersion && semver.lt(prismaVersion, PRISMA_MINIMUM_VERSION)) { @@ -113,7 +125,7 @@ export default class PrismaSchemaGenerator { break; case GeneratorDecl: - this.generateGenerator(prisma, decl as GeneratorDecl); + this.generateGenerator(prisma, decl as GeneratorDecl, options); break; } } @@ -220,7 +232,7 @@ export default class PrismaSchemaGenerator { return JSON.stringify(expr.value); } - private generateGenerator(prisma: PrismaModel, decl: GeneratorDecl) { + private generateGenerator(prisma: PrismaModel, decl: GeneratorDecl, options: PluginOptions) { const generator = prisma.addGenerator( decl.name, decl.fields.map((f) => ({ name: f.name, text: this.configExprToText(f.value) })) @@ -262,13 +274,31 @@ export default class PrismaSchemaGenerator { } } } + + if (typeof options.overrideClientGenerationPath === 'string') { + const output = generator.fields.find((f) => f.name === 'output'); + if (output) { + output.text = JSON.stringify(options.overrideClientGenerationPath); + } else { + generator.fields.push({ + name: 'output', + text: JSON.stringify(options.overrideClientGenerationPath), + }); + } + } } } private generateModel(prisma: PrismaModel, decl: DataModel) { const model = decl.isView ? prisma.addView(decl.name) : prisma.addModel(decl.name); for (const field of decl.fields) { - this.generateModelField(model, field); + if (field.$inheritedFrom) { + if (field.$inheritedFrom.isAbstract || this.mode === 'logical' || isIdField(field)) { + this.generateModelField(model, field); + } + } else { + this.generateModelField(model, field); + } } for (const attr of decl.attributes.filter((attr) => this.isPrismaAttribute(attr))) { @@ -281,6 +311,148 @@ export default class PrismaSchemaGenerator { // user defined comments pass-through decl.comments.forEach((c) => model.addComment(c)); + + // generate relation fields on base models linking to concrete models + this.generateDelegateRelationForBase(model, decl); + + // generate reverse relation fields on concrete models + this.generateDelegateRelationForConcrete(model, decl); + + // expand relations on other models that reference delegated models to concrete models + this.expandPolymorphicRelations(model, decl); + } + + private generateDelegateRelationForBase(model: PrismaDataModel, decl: DataModel) { + if (this.mode !== 'physical') { + return; + } + + if (!isDelegateModel(decl)) { + return; + } + + // collect concrete models inheriting this model + const concreteModels = decl.$container.declarations.filter( + (d) => isDataModel(d) && d !== decl && d.superTypes.some((base) => base.ref === decl) + ); + + // generate an optional relation field in delegate base model to each concrete model + concreteModels.forEach((concrete) => { + const auxName = `${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(concrete.name)}`; + model.addField(auxName, new ModelFieldType(concrete.name, false, true)); + }); + } + + private generateDelegateRelationForConcrete(model: PrismaDataModel, concreteDecl: DataModel) { + if (this.mode !== 'physical') { + return; + } + + // generate a relation field for each delegated base model + + const baseModels = concreteDecl.superTypes + .map((t) => t.ref) + .filter((t): t is DataModel => !!t) + .filter((t) => isDelegateModel(t)); + + baseModels.forEach((base) => { + const idFields = getIdFields(base); + + // add relation fields + const relationField = `${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(base.name)}`; + model.addField(relationField, base.name, [ + new PrismaFieldAttribute('@relation', [ + new PrismaAttributeArg( + 'fields', + new AttributeArgValue( + 'Array', + idFields.map( + (idField) => + new AttributeArgValue('FieldReference', new PrismaFieldReference(idField.name)) + ) + ) + ), + new PrismaAttributeArg( + 'references', + new AttributeArgValue( + 'Array', + idFields.map( + (idField) => + new AttributeArgValue('FieldReference', new PrismaFieldReference(idField.name)) + ) + ) + ), + new PrismaAttributeArg( + 'onDelete', + new AttributeArgValue('FieldReference', new PrismaFieldReference('Cascade')) + ), + new PrismaAttributeArg( + 'onUpdate', + new AttributeArgValue('FieldReference', new PrismaFieldReference('Cascade')) + ), + ]), + ]); + }); + } + + private expandPolymorphicRelations(model: PrismaDataModel, decl: DataModel) { + if (this.mode !== 'logical') { + return; + } + + // the logical schema needs to expand relations to the delegate models to concrete ones + + // for the given model, find all concrete models that have relation to it, + // and generate an auxiliary opposite relation field + decl.fields.forEach((f) => { + const fieldType = f.type.reference?.ref; + if (!isDataModel(fieldType)) { + return; + } + + // find concrete models that inherit from this field's model type + const concreteModels = decl.$container.declarations.filter( + (d) => isDataModel(d) && isDescendantOf(d, fieldType) + ); + + concreteModels.forEach((concrete) => { + const relationField = model.addField( + `${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(concrete.name)}`, + new ModelFieldType(concrete.name, f.type.array, f.type.optional) + ); + const relAttr = getAttribute(f, '@relation'); + if (relAttr) { + const fieldsArg = relAttr.args.find((arg) => arg.name === 'fields'); + if (fieldsArg) { + const idFields = getIdFields(fieldType); + idFields.forEach((idField) => { + model.addField( + `${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(concrete.name)}${upperCaseFirst( + idField.name + )}`, + idField.type.type! + ); + }); + + const args = new AttributeArgValue( + 'Array', + idFields.map( + (idField) => + new AttributeArgValue('FieldReference', new PrismaFieldReference(idField.name)) + ) + ); + relationField.attributes.push( + new PrismaFieldAttribute('@relation', [ + new PrismaAttributeArg('fields', args), + new PrismaAttributeArg('references', args), + ]) + ); + } else { + relationField.attributes.push(this.makeFieldAttribute(relAttr as DataModelFieldAttribute)); + } + } + }); + }); } private isPrismaAttribute(attr: DataModelAttribute | DataModelFieldAttribute) { @@ -309,7 +481,7 @@ export default class PrismaSchemaGenerator { } } - private generateModelField(model: PrismaDataModel, field: DataModelField) { + private generateModelField(model: PrismaDataModel, field: DataModelField, addToFront = false) { const fieldType = field.type.type || field.type.reference?.ref?.name || this.getUnsupportedFieldType(field.type); if (!fieldType) { @@ -318,34 +490,42 @@ export default class PrismaSchemaGenerator { const type = new ModelFieldType(fieldType, field.type.array, field.type.optional); - const attributes = this.getAttributesToGenerate(field); + const attributes = field.attributes + .filter((attr) => this.isPrismaAttribute(attr)) + // `@default` with `auth()` is handled outside Prisma + .filter((attr) => !this.isDefaultWithAuth(attr)) + .filter( + (attr) => + // when building physical schema, exclude `@default` for id fields inherited from delegate base + !( + this.mode === 'physical' && + isIdField(field) && + this.isInheritedFromDelegate(field) && + attr.decl.$refText === '@default' + ) + ) + .map((attr) => this.makeFieldAttribute(attr)); const nonPrismaAttributes = field.attributes.filter((attr) => attr.decl.ref && !this.isPrismaAttribute(attr)); const documentations = nonPrismaAttributes.map((attr) => '/// ' + this.zModelGenerator.generate(attr)); - const result = model.addField(field.name, type, attributes, documentations); + const result = model.addField(field.name, type, attributes, documentations, addToFront); // user defined comments pass-through field.comments.forEach((c) => result.addComment(c)); } - private getAttributesToGenerate(field: DataModelField) { - if (this.hasDefaultWithAuth(field)) { - return []; - } - return field.attributes - .filter((attr) => this.isPrismaAttribute(attr)) - .map((attr) => this.makeFieldAttribute(attr)); + private isInheritedFromDelegate(field: DataModelField) { + return field.$inheritedFrom && isDelegateModel(field.$inheritedFrom); } - private hasDefaultWithAuth(field: DataModelField) { - const defaultAttr = getAttribute(field, '@default'); - if (!defaultAttr) { + private isDefaultWithAuth(attr: DataModelFieldAttribute) { + if (attr.decl.ref?.name !== '@default') { return false; } - const expr = defaultAttr.args[0]?.value; + const expr = attr.args[0]?.value; if (!expr) { return false; } @@ -469,6 +649,10 @@ export default class PrismaSchemaGenerator { } } +function isDescendantOf(model: DataModel, superModel: DataModel): boolean { + return model.superTypes.some((s) => s.ref === superModel || isDescendantOf(s.ref!, superModel)); +} + export function getDefaultPrismaOutputFile(schemaPath: string) { // handle override from package.json const pkgJsonPath = findPackageJson(path.dirname(schemaPath)); diff --git a/packages/schema/src/res/stdlib.zmodel b/packages/schema/src/res/stdlib.zmodel index e95099498..721dee538 100644 --- a/packages/schema/src/res/stdlib.zmodel +++ b/packages/schema/src/res/stdlib.zmodel @@ -598,3 +598,13 @@ attribute @prisma.passthrough(_ text: String) * A utility attribute to allow passthrough of arbitrary attribute text to the generated Prisma schema. */ attribute @@prisma.passthrough(_ text: String) + +/** + * Marks a model to be a delegate. Used for implementing polymorphism. + */ +attribute @@delegate(_ discriminator: FieldReference) + +// /** +// * Marks a field to be the discriminator that identifies model's type in a polymorphic hierarchy. +// */ +// attribute @discriminator() diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index 1e2850577..2688987a2 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -1,24 +1,29 @@ import { BinaryExpr, DataModel, + DataModelField, Expression, InheritableNode, + isArrayExpr, isBinaryExpr, isDataModel, + isDataModelField, + isInvocationExpr, + isMemberAccessExpr, isModel, + isReferenceExpr, Model, ModelImport, + ReferenceExpr, } from '@zenstackhq/language/ast'; +import { isFromStdlib } from '@zenstackhq/sdk'; import { AstNode, + copyAstNode, CstNode, - GenericAstNode, getContainerOfType, getDocument, - isAstNode, - isReference, LangiumDocuments, - linkContentToContainer, Linker, Mutable, Reference, @@ -41,23 +46,32 @@ type BuildReference = ( export function mergeBaseModel(model: Model, linker: Linker) { const buildReference = linker.buildReference.bind(linker); - model.declarations - .filter((x) => x.$type === 'DataModel') - .forEach((decl) => { - const dataModel = decl as DataModel; + model.declarations.filter(isDataModel).forEach((decl) => { + const dataModel = decl as DataModel; - dataModel.fields = dataModel.superTypes + const bases = getRecursiveBases(dataModel).reverse(); + if (bases.length > 0) { + dataModel.fields = bases // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - .flatMap((superType) => superType.ref!.fields) + .flatMap((base) => base.fields) + // don't inherit skip-level fields + .filter((f) => !f.$inheritedFrom) .map((f) => cloneAst(f, dataModel, buildReference)) .concat(dataModel.fields); - dataModel.attributes = dataModel.superTypes + dataModel.attributes = bases // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - .flatMap((superType) => superType.ref!.attributes) + .flatMap((base) => base.attributes) + // don't inherit skip-level attributes + .filter((attr) => !attr.$inheritedFrom) + // don't inherit `@@delegate` attribute + .filter((attr) => attr.decl.$refText !== '@@delegate') .map((attr) => cloneAst(attr, dataModel, buildReference)) .concat(dataModel.attributes); - }); + } + + dataModel.$baseMerged = true; + }); // remove abstract models model.declarations = model.declarations.filter((x) => !(isDataModel(x) && x.isAbstract)); @@ -73,40 +87,49 @@ function cloneAst( clone.$container = newContainer; clone.$containerProperty = node.$containerProperty; clone.$containerIndex = node.$containerIndex; - clone.$inheritedFrom = getContainerOfType(node, isDataModel); + clone.$inheritedFrom = node.$inheritedFrom ?? getContainerOfType(node, isDataModel); return clone; } -// this function is copied from Langium's ast-utils, but copying $resolvedType as well -function copyAstNode(node: T, buildReference: BuildReference): T { - const copy: GenericAstNode = { $type: node.$type, $resolvedType: node.$resolvedType }; - - for (const [name, value] of Object.entries(node)) { - if (!name.startsWith('$')) { - if (isAstNode(value)) { - copy[name] = copyAstNode(value, buildReference); - } else if (isReference(value)) { - copy[name] = buildReference(copy, name, value.$refNode, value.$refText); - } else if (Array.isArray(value)) { - const copiedArray: unknown[] = []; - for (const element of value) { - if (isAstNode(element)) { - copiedArray.push(copyAstNode(element, buildReference)); - } else if (isReference(element)) { - copiedArray.push(buildReference(copy, name, element.$refNode, element.$refText)); - } else { - copiedArray.push(element); - } - } - copy[name] = copiedArray; - } else { - copy[name] = value; +export function getIdFields(dataModel: DataModel) { + const fieldLevelId = getModelFieldsWithBases(dataModel).find((f) => + f.attributes.some((attr) => attr.decl.$refText === '@id') + ); + if (fieldLevelId) { + return [fieldLevelId]; + } else { + // get model level @@id attribute + const modelIdAttr = dataModel.attributes.find((attr) => attr.decl?.ref?.name === '@@id'); + if (modelIdAttr) { + // get fields referenced in the attribute: @@id([field1, field2]]) + if (!isArrayExpr(modelIdAttr.args[0]?.value)) { + return []; } + const argValue = modelIdAttr.args[0].value; + return argValue.items + .filter((expr): expr is ReferenceExpr => isReferenceExpr(expr) && !!getDataModelFieldReference(expr)) + .map((expr) => expr.target.ref as DataModelField); } } + return []; +} + +export function isAuthInvocation(node: AstNode) { + return isInvocationExpr(node) && node.function.ref?.name === 'auth' && isFromStdlib(node.function.ref); +} - linkContentToContainer(copy); - return copy as unknown as T; +export function isFutureInvocation(node: AstNode) { + return isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref); +} + +export function getDataModelFieldReference(expr: Expression): DataModelField | undefined { + if (isReferenceExpr(expr) && isDataModelField(expr.target.ref)) { + return expr.target.ref; + } else if (isMemberAccessExpr(expr) && isDataModelField(expr.member.ref)) { + return expr.member.ref; + } else { + return undefined; + } } export function resolveImportUri(imp: ModelImport): URI | undefined { @@ -183,3 +206,23 @@ export function getContainingDataModel(node: Expression): DataModel | undefined } return undefined; } + +export function getModelFieldsWithBases(model: DataModel) { + if (model.$baseMerged) { + return model.fields; + } else { + return [...model.fields, ...getRecursiveBases(model).flatMap((base) => base.fields)]; + } +} + +export function getRecursiveBases(dataModel: DataModel): DataModel[] { + const result: DataModel[] = []; + dataModel.superTypes.forEach((superType) => { + const baseDecl = superType.ref; + if (baseDecl) { + result.push(baseDecl); + result.push(...getRecursiveBases(baseDecl)); + } + }); + return result; +} diff --git a/packages/schema/tests/generator/prisma-generator.test.ts b/packages/schema/tests/generator/prisma-generator.test.ts index d2f425e53..67ba27f99 100644 --- a/packages/schema/tests/generator/prisma-generator.test.ts +++ b/packages/schema/tests/generator/prisma-generator.test.ts @@ -5,7 +5,7 @@ import fs from 'fs'; import path from 'path'; import tmp from 'tmp'; import { loadDocument } from '../../src/cli/cli-util'; -import PrismaSchemaGenerator from '../../src/plugins/prisma/schema-generator'; +import { PrismaSchemaGenerator } from '../../src/plugins/prisma/schema-generator'; import { execSync } from '../../src/utils/exec-utils'; import { loadModel } from '../utils'; @@ -364,6 +364,7 @@ describe('Prisma generator test', () => { output: name, generateClient: false, }); + console.log('Generated:', name); const content = fs.readFileSync(name, 'utf-8'); const dmmf = await getDMMF({ datamodel: content }); @@ -372,9 +373,7 @@ describe('Prisma generator test', () => { const post = dmmf.datamodel.models[0]; expect(post.name).toBe('Post'); expect(post.fields.length).toBe(5); - expect(post.fields[0].name).toBe('id'); - expect(post.fields[3].name).toBe('title'); - expect(post.fields[4].name).toBe('published'); + expect(post.fields.map((f) => f.name)).toEqual(expect.arrayContaining(['id', 'title', 'published'])); }); it('abstract multi files', async () => { diff --git a/packages/schema/tests/schema/validation/datamodel-validation.test.ts b/packages/schema/tests/schema/validation/datamodel-validation.test.ts index e1f06d268..ec3be8f36 100644 --- a/packages/schema/tests/schema/validation/datamodel-validation.test.ts +++ b/packages/schema/tests/schema/validation/datamodel-validation.test.ts @@ -632,7 +632,9 @@ describe('Data Model Validation Tests', () => { `); expect(errors.length).toBe(1); - expect(errors[0]).toEqual(`Model A cannot be extended because it's not abstract`); + expect(errors[0]).toEqual( + `Model A cannot be extended because it's neither abstract nor marked as "@@delegate"` + ); // relation incomplete from multiple level inheritance expect( diff --git a/packages/sdk/src/model-meta-generator.ts b/packages/sdk/src/model-meta-generator.ts index 9beda653a..cd516f5ec 100644 --- a/packages/sdk/src/model-meta-generator.ts +++ b/packages/sdk/src/model-meta-generator.ts @@ -19,11 +19,13 @@ import { ExpressionContext, getAttribute, getAttributeArg, + getAttributeArgLiteral, getAttributeArgs, getAuthModel, getDataModels, getLiteral, hasAttribute, + isDelegateModel, isAuthInvocation, isEnumFieldReference, isForeignKeyField, @@ -57,133 +59,202 @@ function generateModelMetadata( options: ModelMetaGeneratorOptions ) { writer.block(() => { - writer.write('fields:'); - writer.block(() => { - for (const model of dataModels) { - writer.write(`${lowerCaseFirst(model.name)}:`); - writer.block(() => { - for (const f of model.fields) { - const backlink = getBackLink(f); - const fkMapping = generateForeignKeyMapping(f); - writer.write(`${f.name}: { - name: "${f.name}", - type: "${ - f.type.reference - ? f.type.reference.$refText - : // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - f.type.type! - }",`); - - if (isIdField(f)) { - writer.write(` - isId: true,`); - } - - if (isDataModel(f.type.reference?.ref)) { - writer.write(` - isDataModel: true,`); - } - - if (f.type.array) { - writer.write(` - isArray: true,`); - } - - if (f.type.optional) { - writer.write(` - isOptional: true,`); - } - - if (options.generateAttributes) { - const attrs = getFieldAttributes(f); - if (attrs.length > 0) { - writer.write(` - attributes: ${JSON.stringify(attrs)},`); - } - } else { - // only include essential attributes - const attrs = getFieldAttributes(f).filter((attr) => - ['@default', '@updatedAt'].includes(attr.name) - ); - if (attrs.length > 0) { - writer.write(` - attributes: ${JSON.stringify(attrs)},`); - } - } - - if (backlink) { - writer.write(` - backLink: '${backlink.name}',`); - } - - if (isRelationOwner(f, backlink)) { - writer.write(` - isRelationOwner: true,`); - } - - if (isForeignKeyField(f)) { - writer.write(` - isForeignKey: true,`); - } - - if (fkMapping && Object.keys(fkMapping).length > 0) { - writer.write(` - foreignKeyMapping: ${JSON.stringify(fkMapping)},`); - } - - const defaultValueProvider = generateDefaultValueProvider(f, sourceFile); - if (defaultValueProvider) { - writer.write(` - defaultValueProvider: ${defaultValueProvider},`); - } - - if (isAutoIncrement(f)) { - writer.write(` - isAutoIncrement: true,`); - } - - writer.write(` - },`); - } - }); - writer.write(','); + writeModels(sourceFile, writer, dataModels, options); + writeDeleteCascade(writer, dataModels); + writeAuthModel(writer, dataModels); + }); +} + +function writeModels( + sourceFile: SourceFile, + writer: CodeBlockWriter, + dataModels: DataModel[], + options: ModelMetaGeneratorOptions +) { + writer.write('models:'); + writer.block(() => { + for (const model of dataModels) { + writer.write(`${lowerCaseFirst(model.name)}:`); + writer.block(() => { + writer.write(`name: '${model.name}',`); + writeBaseTypes(writer, model); + writeFields(sourceFile, writer, model, options); + writeUniqueConstraints(writer, model); + if (options.generateAttributes) { + writeModelAttributes(writer, model); + } + writeDiscriminator(writer, model); + }); + writer.writeLine(','); + } + }); + writer.writeLine(','); +} + +function writeBaseTypes(writer: CodeBlockWriter, model: DataModel) { + if (model.superTypes.length > 0) { + writer.write('baseTypes: ['); + writer.write(model.superTypes.map((t) => `'${t.ref?.name}'`).join(', ')); + writer.write('],'); + } +} + +function writeAuthModel(writer: CodeBlockWriter, dataModels: DataModel[]) { + const authModel = getAuthModel(dataModels); + if (authModel) { + writer.writeLine(`authModel: '${authModel.name}'`); + } +} + +function writeDeleteCascade(writer: CodeBlockWriter, dataModels: DataModel[]) { + writer.write('deleteCascade:'); + writer.block(() => { + for (const model of dataModels) { + const cascades = getDeleteCascades(model); + if (cascades.length > 0) { + writer.writeLine(`${lowerCaseFirst(model.name)}: [${cascades.map((n) => `'${n}'`).join(', ')}],`); } - }); - writer.write(','); + } + }); + writer.writeLine(','); +} +function writeUniqueConstraints(writer: CodeBlockWriter, model: DataModel) { + const constraints = getUniqueConstraints(model); + if (constraints.length > 0) { writer.write('uniqueConstraints:'); writer.block(() => { - for (const model of dataModels) { - writer.write(`${lowerCaseFirst(model.name)}:`); - writer.block(() => { - for (const constraint of getUniqueConstraints(model)) { - writer.write(`${constraint.name}: { - name: "${constraint.name}", - fields: ${JSON.stringify(constraint.fields)} - },`); - } - }); - writer.write(','); + for (const constraint of constraints) { + writer.write(`${constraint.name}: { + name: "${constraint.name}", + fields: ${JSON.stringify(constraint.fields)} + },`); } }); writer.write(','); + } +} - writer.write('deleteCascade:'); - writer.block(() => { - for (const model of dataModels) { - const cascades = getDeleteCascades(model); - if (cascades.length > 0) { - writer.writeLine(`${lowerCaseFirst(model.name)}: [${cascades.map((n) => `'${n}'`).join(', ')}],`); +function writeModelAttributes(writer: CodeBlockWriter, model: DataModel) { + const attrs = getAttributes(model); + if (attrs.length > 0) { + writer.write(` +attributes: ${JSON.stringify(attrs)},`); + } +} + +function writeDiscriminator(writer: CodeBlockWriter, model: DataModel) { + const delegateAttr = getAttribute(model, '@@delegate'); + if (!delegateAttr) { + return; + } + const discriminator = getAttributeArg(delegateAttr, 'discriminator') as ReferenceExpr; + if (!discriminator) { + return; + } + if (discriminator) { + writer.write(`discriminator: ${JSON.stringify(discriminator.target.$refText)},`); + } +} + +function writeFields( + sourceFile: SourceFile, + writer: CodeBlockWriter, + model: DataModel, + options: ModelMetaGeneratorOptions +) { + writer.write('fields:'); + writer.block(() => { + for (const f of model.fields) { + const backlink = getBackLink(f); + const fkMapping = generateForeignKeyMapping(f); + writer.write(`${f.name}: {`); + + writer.write(` + name: "${f.name}", + type: "${ + f.type.reference + ? f.type.reference.$refText + : // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + f.type.type! + }",`); + + if (isIdField(f)) { + writer.write(` + isId: true,`); + } + + if (isDataModel(f.type.reference?.ref)) { + writer.write(` + isDataModel: true,`); + } + + if (f.type.array) { + writer.write(` + isArray: true,`); + } + + if (f.type.optional) { + writer.write(` + isOptional: true,`); + } + + if (options.generateAttributes) { + const attrs = getAttributes(f); + if (attrs.length > 0) { + writer.write(` + attributes: ${JSON.stringify(attrs)},`); + } + } else { + // only include essential attributes + const attrs = getAttributes(f).filter((attr) => ['@default', '@updatedAt'].includes(attr.name)); + if (attrs.length > 0) { + writer.write(` + attributes: ${JSON.stringify(attrs)},`); } } - }); - writer.write(','); - const authModel = getAuthModel(dataModels); - if (authModel) { - writer.writeLine(`authModel: '${authModel.name}'`); + if (backlink) { + writer.write(` + backLink: '${backlink.name}',`); + } + + if (isRelationOwner(f, backlink)) { + writer.write(` + isRelationOwner: true,`); + } + + if (isForeignKeyField(f)) { + writer.write(` + isForeignKey: true,`); + } + + if (fkMapping && Object.keys(fkMapping).length > 0) { + writer.write(` + foreignKeyMapping: ${JSON.stringify(fkMapping)},`); + } + + const defaultValueProvider = generateDefaultValueProvider(f, sourceFile); + if (defaultValueProvider) { + writer.write(` + defaultValueProvider: ${defaultValueProvider},`); + } + + if (f.$inheritedFrom && isDelegateModel(f.$inheritedFrom) && !isIdField(f)) { + writer.write(` + inheritedFrom: ${JSON.stringify(f.$inheritedFrom.name)},`); + } + + if (isAutoIncrement(f)) { + writer.write(` + isAutoIncrement: true,`); + } + + writer.write(` + },`); } }); + writer.write(','); } function getBackLink(field: DataModelField) { @@ -212,13 +283,15 @@ function getBackLink(field: DataModelField) { } function getRelationName(field: DataModelField) { - const relAttr = field.attributes.find((attr) => attr.decl.ref?.name === 'relation'); - const relName = relAttr && relAttr.args?.[0] && getLiteral(relAttr.args?.[0].value); - return relName; + const relAttr = getAttribute(field, '@relation'); + if (!relAttr) { + return undefined; + } + return getAttributeArgLiteral(relAttr, 'name'); } -function getFieldAttributes(field: DataModelField): RuntimeAttribute[] { - return field.attributes +function getAttributes(target: DataModelField | DataModel): RuntimeAttribute[] { + return target.attributes .map((attr) => { const args: Array<{ name?: string; value: unknown }> = []; for (const arg of attr.args) { diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index ed841dbc7..0bd98e63e 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -178,7 +178,7 @@ export function isDataModelFieldReference(node: AstNode): node is ReferenceExpr * Gets `@@id` fields declared at the data model level */ export function getModelIdFields(model: DataModel) { - const idAttr = model.attributes.find((attr) => attr.decl.ref?.name === '@@id'); + const idAttr = model.attributes.find((attr) => attr.decl.$refText === '@@id'); if (!idAttr) { return []; } @@ -196,7 +196,7 @@ export function getModelIdFields(model: DataModel) { * Gets `@@unique` fields declared at the data model level */ export function getModelUniqueFields(model: DataModel) { - const uniqueAttr = model.attributes.find((attr) => attr.decl.ref?.name === '@@unique'); + const uniqueAttr = model.attributes.find((attr) => attr.decl.$refText === '@@unique'); if (!uniqueAttr) { return []; } @@ -379,6 +379,10 @@ export function getAuthModel(dataModels: DataModel[]) { return authModel; } +export function isDelegateModel(node: AstNode) { + return isDataModel(node) && hasAttribute(node, '@@delegate'); +} + export function getIdFields(dataModel: DataModel) { const fieldLevelId = getModelFieldsWithBases(dataModel).find((f) => f.attributes.some((attr) => attr.decl.$refText === '@id') diff --git a/packages/server/src/api/rest/index.ts b/packages/server/src/api/rest/index.ts index 88b463c80..52d700c63 100644 --- a/packages/server/src/api/rest/index.ts +++ b/packages/server/src/api/rest/index.ts @@ -956,7 +956,7 @@ class RequestHandler extends APIHandlerBase { private buildTypeMap(logger: LoggerConfig | undefined, modelMeta: ModelMeta): void { this.typeMap = {}; - for (const [model, fields] of Object.entries(modelMeta.fields)) { + for (const [model, { fields }] of Object.entries(modelMeta.models)) { const idFields = getIdFields(modelMeta, model); if (idFields.length === 0) { logWarning(logger, `Not including model ${model} in the API because it has no ID field`); @@ -1013,7 +1013,7 @@ class RequestHandler extends APIHandlerBase { this.serializers = new Map(); const linkers: Record> = {}; - for (const model of Object.keys(modelMeta.fields)) { + for (const model of Object.keys(modelMeta.models)) { const ids = getIdFields(modelMeta, model); if (ids.length !== 1) { continue; @@ -1027,7 +1027,7 @@ class RequestHandler extends APIHandlerBase { linkers[model] = linker; let projection: Record | null = {}; - for (const [field, fieldMeta] of Object.entries(modelMeta.fields[model])) { + for (const [field, fieldMeta] of Object.entries(modelMeta.models[model].fields)) { if (fieldMeta.isDataModel) { projection[field] = 0; } @@ -1049,14 +1049,14 @@ class RequestHandler extends APIHandlerBase { } // set relators - for (const model of Object.keys(modelMeta.fields)) { + for (const model of Object.keys(modelMeta.models)) { const serializer = this.serializers.get(model); if (!serializer) { continue; } const relators: Record> = {}; - for (const [field, fieldMeta] of Object.entries(modelMeta.fields[model])) { + for (const [field, fieldMeta] of Object.entries(modelMeta.models[model].fields)) { if (!fieldMeta.isDataModel) { continue; } diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index 2b501abc0..bd64d6461 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -2,7 +2,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import type { DMMF } from '@prisma/generator-helper'; import type { Model } from '@zenstackhq/language/ast'; -import type { AuthUser, DbOperations, EnhancementOptions } from '@zenstackhq/runtime'; +import type { AuthUser, CrudContract, EnhancementKind, EnhancementOptions } from '@zenstackhq/runtime'; import { getDMMF } from '@zenstackhq/sdk'; import { execSync } from 'child_process'; import * as fs from 'fs'; @@ -24,7 +24,7 @@ import prismaPlugin from 'zenstack/plugins/prisma'; */ export const FILE_SPLITTER = '#FILE_SPLITTER#'; -export type FullDbClientContract = Record & { +export type FullDbClientContract = CrudContract & { $on(eventType: any, callback: (event: any) => void): void; $use(cb: any): void; $disconnect: () => Promise; @@ -81,7 +81,6 @@ datasource db { generator js { provider = 'prisma-client-js' - previewFeatures = ['clientExtensions'] } plugin enhancer { @@ -111,6 +110,8 @@ export type SchemaLoadOptions = { dbUrl?: string; pulseApiKey?: string; getPrismaOnly?: boolean; + enhancements?: EnhancementKind[]; + enhanceOptions?: Partial; }; const defaultOptions: SchemaLoadOptions = { @@ -283,8 +284,9 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) { modelMeta, zodSchemas, logPrismaQuery: opt.logPrismaQuery, - transactionTimeout: 10000, - ...options, + transactionTimeout: 1000000, + kinds: opt.enhancements, + ...(options ?? opt.enhanceOptions), } ), enhanceRaw: enhance, diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 74426b3ad..9c8f28205 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -400,12 +400,18 @@ importers: change-case: specifier: ^4.1.2 version: 4.1.2 + colors: + specifier: 1.4.0 + version: 1.4.0 decimal.js: specifier: ^10.4.2 version: 10.4.2 deepcopy: specifier: ^2.1.0 version: 2.1.0 + deepmerge: + specifier: ^4.3.1 + version: 4.3.1 lower-case-first: specifier: ^2.0.2 version: 2.0.2 diff --git a/tests/integration/package.json b/tests/integration/package.json index 40627f354..cace90307 100644 --- a/tests/integration/package.json +++ b/tests/integration/package.json @@ -5,7 +5,7 @@ "main": "index.js", "scripts": { "lint": "eslint . --ext .ts", - "test": "ZENSTACK_TEST=1 jest" + "test": "ZENSTACK_TEST=1 jest --runInBand" }, "keywords": [], "author": "", diff --git a/tests/integration/tests/cli/init.test.ts b/tests/integration/tests/cli/init.test.ts index 987752bd2..6b5ae7c3a 100644 --- a/tests/integration/tests/cli/init.test.ts +++ b/tests/integration/tests/cli/init.test.ts @@ -9,7 +9,8 @@ import { createProgram } from '../../../../packages/schema/src/cli'; import { execSync } from '../../../../packages/schema/src/utils/exec-utils'; import { createNpmrc } from './share'; -describe('CLI init command tests', () => { +// eslint-disable-next-line jest/no-disabled-tests +describe.skip('CLI init command tests', () => { let origDir: string; beforeEach(() => { @@ -23,6 +24,7 @@ describe('CLI init command tests', () => { process.chdir(origDir); }); + // eslint-disable-next-line jest/no-disabled-tests it('init project t3 npm std', async () => { execSync('npx --yes create-t3-app@latest --prisma --CI --noGit .', { stdio: 'inherit', @@ -42,9 +44,7 @@ describe('CLI init command tests', () => { checkDependency('@zenstackhq/runtime', false, true); }); - // Disabled because it blows up memory on MAC, not sure why ... - // eslint-disable-next-line jest/no-disabled-tests - it.skip('init project t3 yarn std', async () => { + it('init project t3 yarn std', async () => { execSync('npx --yes create-t3-app@latest --prisma --CI --noGit .', { stdio: 'inherit', env: { diff --git a/tests/integration/tests/enhancements/with-delegate/policy.test.ts b/tests/integration/tests/enhancements/with-delegate/policy.test.ts new file mode 100644 index 000000000..d0316595d --- /dev/null +++ b/tests/integration/tests/enhancements/with-delegate/policy.test.ts @@ -0,0 +1,217 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Polymorphic Policy Test', () => { + it('simple boolean', async () => { + const booleanCondition = ` + model User { + id Int @id @default(autoincrement()) + level Int @default(0) + assets Asset[] + banned Boolean @default(false) + + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + published Boolean @default(false) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + assetType String + viewCount Int @default(0) + + @@delegate(assetType) + @@allow('create', viewCount >= 0) + @@deny('read', !published) + @@allow('read', true) + @@deny('all', owner.banned) + } + + model Video extends Asset { + watched Boolean @default(false) + videoType String + + @@delegate(videoType) + @@deny('read', !watched) + @@allow('read', true) + } + + model RatedVideo extends Video { + rated Boolean @default(false) + @@deny('read', !rated) + @@allow('read', true) + } + `; + + const booleanExpression = ` + model User { + id Int @id @default(autoincrement()) + level Int @default(0) + assets Asset[] + banned Boolean @default(false) + + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + published Boolean @default(false) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + assetType String + viewCount Int @default(0) + + @@delegate(assetType) + @@allow('create', viewCount >= 0) + @@deny('read', published == false) + @@allow('read', true) + @@deny('all', owner.banned == true) + } + + model Video extends Asset { + watched Boolean @default(false) + videoType String + + @@delegate(videoType) + @@deny('read', watched == false) + @@allow('read', true) + } + + model RatedVideo extends Video { + rated Boolean @default(false) + @@deny('read', rated == false) + @@allow('read', true) + } + `; + + for (const schema of [booleanCondition, booleanExpression]) { + const { enhanceRaw: enhance, prisma } = await loadSchema(schema); + + const fullDb = enhance(prisma, undefined, { kinds: ['delegate'], logPrismaQuery: true }); + + const user = await fullDb.user.create({ data: { id: 1 } }); + const userDb = enhance( + prisma, + { user: { id: user.id } }, + { kinds: ['delegate', 'policy'], logPrismaQuery: true } + ); + + // violating Asset create + await expect( + userDb.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: -1 }, + }) + ).toBeRejectedByPolicy(); + + let video = await fullDb.ratedVideo.create({ + data: { owner: { connect: { id: user.id } } }, + }); + // violating all three layer read + await expect(userDb.asset.findUnique({ where: { id: video.id } })).toResolveNull(); + await expect(userDb.video.findUnique({ where: { id: video.id } })).toResolveNull(); + await expect(userDb.ratedVideo.findUnique({ where: { id: video.id } })).toResolveNull(); + + video = await fullDb.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, published: true }, + }); + // violating Video && RatedVideo read + await expect(userDb.asset.findUnique({ where: { id: video.id } })).toResolveTruthy(); + await expect(userDb.video.findUnique({ where: { id: video.id } })).toResolveNull(); + await expect(userDb.ratedVideo.findUnique({ where: { id: video.id } })).toResolveNull(); + + video = await fullDb.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, published: true, watched: true }, + }); + // violating RatedVideo read + await expect(userDb.asset.findUnique({ where: { id: video.id } })).toResolveTruthy(); + await expect(userDb.video.findUnique({ where: { id: video.id } })).toResolveTruthy(); + await expect(userDb.ratedVideo.findUnique({ where: { id: video.id } })).toResolveNull(); + + video = await fullDb.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, rated: true, watched: true, published: true }, + }); + // meeting all read conditions + await expect(userDb.asset.findUnique({ where: { id: video.id } })).toResolveTruthy(); + await expect(userDb.video.findUnique({ where: { id: video.id } })).toResolveTruthy(); + await expect(userDb.ratedVideo.findUnique({ where: { id: video.id } })).toResolveTruthy(); + + // ban the user + await prisma.user.update({ where: { id: user.id }, data: { banned: true } }); + + // banned user can't read + await expect(userDb.asset.findUnique({ where: { id: video.id } })).toResolveNull(); + await expect(userDb.video.findUnique({ where: { id: video.id } })).toResolveNull(); + await expect(userDb.ratedVideo.findUnique({ where: { id: video.id } })).toResolveNull(); + + // banned user can't create + await expect( + userDb.ratedVideo.create({ + data: { owner: { connect: { id: user.id } } }, + }) + ).toBeRejectedByPolicy(); + } + }); + + it('interaction with updateMany/deleteMany', async () => { + const schema = ` + model User { + id Int @id @default(autoincrement()) + level Int @default(0) + assets Asset[] + banned Boolean @default(false) + + @@allow('all', true) + } + + model Asset { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + published Boolean @default(false) + owner User @relation(fields: [ownerId], references: [id]) + ownerId Int + assetType String + viewCount Int @default(0) + version Int @default(0) + + @@delegate(assetType) + @@deny('update', viewCount > 0) + @@deny('delete', viewCount > 0) + @@allow('all', true) + } + + model Video extends Asset { + watched Boolean @default(false) + + @@deny('update', watched) + @@deny('delete', watched) + } + `; + + const { enhance } = await loadSchema(schema, { + logPrismaQuery: true, + }); + const db = enhance(); + + const user = await db.user.create({ data: { id: 1 } }); + const vid1 = await db.video.create({ + data: { watched: false, viewCount: 0, owner: { connect: { id: user.id } } }, + }); + const vid2 = await db.video.create({ + data: { watched: true, viewCount: 1, owner: { connect: { id: user.id } } }, + }); + + await expect(db.asset.updateMany({ data: { version: { increment: 1 } } })).resolves.toMatchObject({ + count: 1, + }); + await expect(db.asset.findUnique({ where: { id: vid1.id } })).resolves.toMatchObject({ version: 1 }); + await expect(db.asset.findUnique({ where: { id: vid2.id } })).resolves.toMatchObject({ version: 0 }); + + await expect(db.asset.deleteMany()).resolves.toMatchObject({ + count: 1, + }); + await expect(db.asset.findUnique({ where: { id: vid1.id } })).toResolveNull(); + await expect(db.asset.findUnique({ where: { id: vid2.id } })).toResolveTruthy(); + }); +}); diff --git a/tests/integration/tests/enhancements/with-delegate/polymorphism.test.ts b/tests/integration/tests/enhancements/with-delegate/polymorphism.test.ts new file mode 100644 index 000000000..0d0b24ca2 --- /dev/null +++ b/tests/integration/tests/enhancements/with-delegate/polymorphism.test.ts @@ -0,0 +1,1015 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import { PrismaErrorCode } from '@zenstackhq/runtime'; + +describe('Polymorphism Test', () => { + const schema = ` +model User { + id Int @id @default(autoincrement()) + level Int @default(0) + assets Asset[] + ratedVideos RatedVideo[] @relation('direct') + + @@allow('all', true) +} + +model Asset { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + viewCount Int @default(0) + owner User? @relation(fields: [ownerId], references: [id]) + ownerId Int? + assetType String + + @@delegate(assetType) + @@allow('all', true) +} + +model Video extends Asset { + duration Int + url String + videoType String + + @@delegate(videoType) +} + +model RatedVideo extends Video { + rating Int + user User? @relation(name: 'direct', fields: [userId], references: [id]) + userId Int? +} + +model Image extends Asset { + format String + gallery Gallery? @relation(fields: [galleryId], references: [id]) + galleryId Int? +} + +model Gallery { + id Int @id @default(autoincrement()) + images Image[] +} +`; + + async function setup() { + const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const db = enhance(); + + const user = await db.user.create({ data: { id: 1 } }); + + const video = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'xyz', rating: 100 }, + }); + + const videoWithOwner = await db.ratedVideo.findUnique({ where: { id: video.id }, include: { owner: true } }); + + return { db, video, user, videoWithOwner }; + } + + it('create hierarchy', async () => { + const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const db = enhance(); + + const user = await db.user.create({ data: { id: 1 } }); + + const video = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'xyz', rating: 100 }, + include: { owner: true }, + }); + + expect(video).toMatchObject({ + viewCount: 1, + duration: 100, + url: 'xyz', + rating: 100, + assetType: 'Video', + videoType: 'RatedVideo', + owner: user, + }); + + await expect(db.asset.create({ data: { type: 'Video' } })).rejects.toThrow('is a delegate'); + await expect(db.video.create({ data: { type: 'RatedVideo' } })).rejects.toThrow('is a delegate'); + + const image = await db.image.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, format: 'png' }, + include: { owner: true }, + }); + expect(image).toMatchObject({ + viewCount: 1, + format: 'png', + assetType: 'Image', + owner: user, + }); + + // create in a nested payload + const gallery = await db.gallery.create({ + data: { + images: { + create: [ + { owner: { connect: { id: user.id } }, format: 'png', viewCount: 1 }, + { owner: { connect: { id: user.id } }, format: 'jpg', viewCount: 2 }, + ], + }, + }, + include: { images: { include: { owner: true } } }, + }); + expect(gallery.images).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + format: 'png', + assetType: 'Image', + viewCount: 1, + owner: user, + }), + expect.objectContaining({ + format: 'jpg', + assetType: 'Image', + viewCount: 2, + owner: user, + }), + ]) + ); + }); + + it('create with base all defaults', async () => { + const { enhance } = await loadSchema( + ` + model Base { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + type String + + @@delegate(type) + } + + model Foo extends Base { + name String + } + `, + { logPrismaQuery: true, enhancements: ['delegate'] } + ); + + const db = enhance(); + const r = await db.foo.create({ data: { name: 'foo' } }); + expect(r).toMatchObject({ name: 'foo', type: 'Foo', id: expect.any(Number), createdAt: expect.any(Date) }); + }); + + it('create with nesting', async () => { + const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const db = enhance(); + + // nested create a relation from base + await expect( + db.ratedVideo.create({ + data: { owner: { create: { id: 2 } }, url: 'xyz', rating: 200, duration: 200 }, + include: { owner: true }, + }) + ).resolves.toMatchObject({ owner: { id: 2 } }); + }); + + it('read with concrete', async () => { + const { db, user, video } = await setup(); + + // find with include + let found = await db.ratedVideo.findFirst({ include: { owner: true } }); + expect(found).toMatchObject(video); + expect(found.owner).toMatchObject(user); + + // find with select + found = await db.ratedVideo.findFirst({ select: { id: true, createdAt: true, url: true, rating: true } }); + expect(found).toMatchObject({ id: video.id, createdAt: video.createdAt, url: video.url, rating: video.rating }); + + // findFirstOrThrow + found = await db.ratedVideo.findFirstOrThrow(); + expect(found).toMatchObject(video); + await expect( + db.ratedVideo.findFirstOrThrow({ + where: { id: video.id + 1 }, + }) + ).rejects.toThrow(); + + // findUnique + found = await db.ratedVideo.findUnique({ + where: { id: video.id }, + }); + expect(found).toMatchObject(video); + + // findUniqueOrThrow + found = await db.ratedVideo.findUniqueOrThrow({ + where: { id: video.id }, + }); + expect(found).toMatchObject(video); + await expect( + db.ratedVideo.findUniqueOrThrow({ + where: { id: video.id + 1 }, + }) + ).rejects.toThrow(); + + // findMany + let items = await db.ratedVideo.findMany(); + expect(items).toHaveLength(1); + expect(items[0]).toMatchObject(video); + + // findMany not found + items = await db.ratedVideo.findMany({ where: { id: video.id + 1 } }); + expect(items).toHaveLength(0); + + // findMany with select + items = await db.ratedVideo.findMany({ select: { id: true, createdAt: true, url: true, rating: true } }); + expect(items).toHaveLength(1); + expect(items[0]).toMatchObject({ + id: video.id, + createdAt: video.createdAt, + url: video.url, + rating: video.rating, + }); + + // find with base filter + found = await db.ratedVideo.findFirst({ where: { viewCount: video.viewCount } }); + expect(found).toMatchObject(video); + found = await db.ratedVideo.findFirst({ where: { url: video.url, owner: { id: user.id } } }); + expect(found).toMatchObject(video); + + // image: single inheritance + const image = await db.image.create({ + data: { owner: { connect: { id: 1 } }, viewCount: 1, format: 'png' }, + include: { owner: true }, + }); + const readImage = await db.image.findFirst({ include: { owner: true } }); + expect(readImage).toMatchObject(image); + expect(readImage.owner).toMatchObject(user); + }); + + it('read with base', async () => { + const { db, user, video: r } = await setup(); + + let video = await db.video.findFirst({ where: { duration: r.duration }, include: { owner: true } }); + expect(video).toMatchObject({ + id: video.id, + createdAt: r.createdAt, + viewCount: r.viewCount, + url: r.url, + duration: r.duration, + assetType: 'Video', + videoType: 'RatedVideo', + }); + expect(video.rating).toBeUndefined(); + expect(video.owner).toMatchObject(user); + + const asset = await db.asset.findFirst({ where: { viewCount: r.viewCount }, include: { owner: true } }); + expect(asset).toMatchObject({ id: r.id, createdAt: r.createdAt, assetType: 'Video', viewCount: r.viewCount }); + expect(asset.url).toBeUndefined(); + expect(asset.duration).toBeUndefined(); + expect(asset.rating).toBeUndefined(); + expect(asset.videoType).toBeUndefined(); + expect(asset.owner).toMatchObject(user); + + const image = await db.image.create({ + data: { owner: { connect: { id: 1 } }, viewCount: 1, format: 'png' }, + include: { owner: true }, + }); + const imgAsset = await db.asset.findFirst({ where: { assetType: 'Image' }, include: { owner: true } }); + expect(imgAsset).toMatchObject({ + id: image.id, + createdAt: image.createdAt, + assetType: 'Image', + viewCount: image.viewCount, + }); + expect(imgAsset.format).toBeUndefined(); + expect(imgAsset.owner).toMatchObject(user); + }); + + it('update simple', async () => { + const { db, videoWithOwner: video } = await setup(); + + // update with concrete + let updated = await db.ratedVideo.update({ + where: { id: video.id }, + data: { rating: 200 }, + include: { owner: true }, + }); + expect(updated.rating).toBe(200); + expect(updated.owner).toBeTruthy(); + + // update with base + updated = await db.video.update({ + where: { id: video.id }, + data: { duration: 200 }, + select: { duration: true, createdAt: true }, + }); + expect(updated.duration).toBe(200); + expect(updated.createdAt).toBeTruthy(); + + // update with base + updated = await db.asset.update({ + where: { id: video.id }, + data: { viewCount: 200 }, + }); + expect(updated.viewCount).toBe(200); + + // set discriminator + await expect(db.ratedVideo.update({ where: { id: video.id }, data: { assetType: 'Image' } })).rejects.toThrow( + 'is a discriminator' + ); + await expect( + db.ratedVideo.update({ where: { id: video.id }, data: { videoType: 'RatedVideo' } }) + ).rejects.toThrow('is a discriminator'); + }); + + it('update nested create', async () => { + const { db, videoWithOwner: video, user } = await setup(); + + // create delegate not allowed + await expect( + db.user.update({ + where: { id: user.id }, + data: { + assets: { + create: { viewCount: 1 }, + }, + }, + include: { assets: true }, + }) + ).rejects.toThrow('is a delegate'); + + // create concrete + await expect( + db.user.update({ + where: { id: user.id }, + data: { + ratedVideos: { + create: { + viewCount: 1, + duration: 100, + url: 'xyz', + rating: 100, + owner: { connect: { id: user.id } }, + }, + }, + }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ + ratedVideos: expect.arrayContaining([ + expect.objectContaining({ viewCount: 1, duration: 100, url: 'xyz', rating: 100 }), + ]), + }); + + // nested create a relation from base + const newVideo = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'xyz', rating: 100 }, + }); + await expect( + db.ratedVideo.update({ + where: { id: newVideo.id }, + data: { owner: { create: { id: 2 } }, url: 'xyz', duration: 200, rating: 200 }, + include: { owner: true }, + }) + ).resolves.toMatchObject({ owner: { id: 2 } }); + }); + + it('update nested updateOne', async () => { + const { db, videoWithOwner: video, user } = await setup(); + + // update + let updated = await db.asset.update({ + where: { id: video.id }, + data: { owner: { update: { level: 1 } } }, + include: { owner: true }, + }); + expect(updated.owner.level).toBe(1); + + updated = await db.video.update({ + where: { id: video.id }, + data: { duration: 300, owner: { update: { level: 2 } } }, + include: { owner: true }, + }); + expect(updated.duration).toBe(300); + expect(updated.owner.level).toBe(2); + + updated = await db.ratedVideo.update({ + where: { id: video.id }, + data: { rating: 300, owner: { update: { level: 3 } } }, + include: { owner: true }, + }); + expect(updated.rating).toBe(300); + expect(updated.owner.level).toBe(3); + }); + + it('update nested updateMany', async () => { + const { db, videoWithOwner: video, user } = await setup(); + + // updateMany + await db.user.update({ + where: { id: user.id }, + data: { + ratedVideos: { + create: { url: 'xyz', duration: 111, rating: 222, owner: { connect: { id: user.id } } }, + }, + }, + }); + await expect( + db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { updateMany: { where: { duration: 111 }, data: { rating: 333 } } } }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ ratedVideos: expect.arrayContaining([expect.objectContaining({ rating: 333 })]) }); + }); + + it('update nested deleteOne', async () => { + const { db, videoWithOwner: video, user } = await setup(); + + // delete with base + await db.user.update({ + where: { id: user.id }, + data: { assets: { delete: { id: video.id } } }, + }); + await expect(db.asset.findUnique({ where: { id: video.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: video.id } })).resolves.toBeNull(); + await expect(db.ratedVideo.findUnique({ where: { id: video.id } })).resolves.toBeNull(); + + // delete with concrete + let vid = await db.ratedVideo.create({ + data: { + user: { connect: { id: user.id } }, + owner: { connect: { id: user.id } }, + url: 'xyz', + duration: 111, + rating: 222, + }, + }); + await db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { delete: { id: vid.id } } }, + }); + await expect(db.asset.findUnique({ where: { id: vid.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: vid.id } })).resolves.toBeNull(); + await expect(db.ratedVideo.findUnique({ where: { id: vid.id } })).resolves.toBeNull(); + + // delete with mixed filter + vid = await db.ratedVideo.create({ + data: { + user: { connect: { id: user.id } }, + owner: { connect: { id: user.id } }, + url: 'xyz', + duration: 111, + rating: 222, + }, + }); + await db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { delete: { id: vid.id, duration: 111 } } }, + }); + await expect(db.asset.findUnique({ where: { id: vid.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: vid.id } })).resolves.toBeNull(); + await expect(db.ratedVideo.findUnique({ where: { id: vid.id } })).resolves.toBeNull(); + + // delete not found + await expect( + db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { delete: { id: vid.id } } }, + }) + ).toBeNotFound(); + }); + + it('update nested deleteMany', async () => { + const { db, videoWithOwner: video, user } = await setup(); + + // delete with base no filter + await db.user.update({ + where: { id: user.id }, + data: { assets: { deleteMany: {} } }, + }); + await expect(db.asset.findUnique({ where: { id: video.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: video.id } })).resolves.toBeNull(); + await expect(db.ratedVideo.findUnique({ where: { id: video.id } })).resolves.toBeNull(); + + // delete with concrete + let vid1 = await db.ratedVideo.create({ + data: { + user: { connect: { id: user.id } }, + owner: { connect: { id: user.id } }, + url: 'abc', + duration: 111, + rating: 111, + }, + }); + let vid2 = await db.ratedVideo.create({ + data: { + user: { connect: { id: user.id } }, + owner: { connect: { id: user.id } }, + url: 'xyz', + duration: 222, + rating: 222, + }, + }); + await db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { deleteMany: { rating: 111 } } }, + }); + await expect(db.asset.findUnique({ where: { id: vid1.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: vid1.id } })).resolves.toBeNull(); + await expect(db.ratedVideo.findUnique({ where: { id: vid1.id } })).resolves.toBeNull(); + await expect(db.asset.findUnique({ where: { id: vid2.id } })).toResolveTruthy(); + await db.asset.deleteMany(); + + // delete with mixed args + vid1 = await db.ratedVideo.create({ + data: { + user: { connect: { id: user.id } }, + owner: { connect: { id: user.id } }, + url: 'abc', + duration: 111, + rating: 111, + viewCount: 111, + }, + }); + vid2 = await db.ratedVideo.create({ + data: { + user: { connect: { id: user.id } }, + owner: { connect: { id: user.id } }, + url: 'xyz', + duration: 222, + rating: 222, + viewCount: 222, + }, + }); + await db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { deleteMany: { url: 'abc', rating: 111, viewCount: 111 } } }, + }); + await expect(db.asset.findUnique({ where: { id: vid1.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: vid1.id } })).resolves.toBeNull(); + await expect(db.ratedVideo.findUnique({ where: { id: vid1.id } })).resolves.toBeNull(); + await expect(db.asset.findUnique({ where: { id: vid2.id } })).toResolveTruthy(); + await db.asset.deleteMany(); + + // delete not found + vid1 = await db.ratedVideo.create({ + data: { + user: { connect: { id: user.id } }, + owner: { connect: { id: user.id } }, + url: 'abc', + duration: 111, + rating: 111, + }, + }); + vid2 = await db.ratedVideo.create({ + data: { + user: { connect: { id: user.id } }, + owner: { connect: { id: user.id } }, + url: 'xyz', + duration: 222, + rating: 222, + }, + }); + await db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { deleteMany: { url: 'abc', rating: 222 } } }, + }); + await expect(db.asset.count()).resolves.toBe(2); + }); + + it('update nested relation manipulation', async () => { + const { db, videoWithOwner: video, user } = await setup(); + + // connect, disconnect with base + await expect( + db.user.update({ + where: { id: user.id }, + data: { assets: { disconnect: { id: video.id } } }, + include: { assets: true }, + }) + ).resolves.toMatchObject({ + assets: expect.arrayContaining([]), + }); + await expect( + db.user.update({ + where: { id: user.id }, + data: { assets: { connect: { id: video.id } } }, + include: { assets: true }, + }) + ).resolves.toMatchObject({ + assets: expect.arrayContaining([expect.objectContaining({ id: video.id })]), + }); + + /// connect, disconnect with concrete + + let vid1 = await db.ratedVideo.create({ + data: { + url: 'abc', + duration: 111, + rating: 111, + }, + }); + let vid2 = await db.ratedVideo.create({ + data: { + url: 'xyz', + duration: 222, + rating: 222, + }, + }); + + // connect not found + await expect( + db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { connect: [{ id: vid2.id + 1 }] } }, + include: { ratedVideos: true }, + }) + ).toBeRejectedWithCode(PrismaErrorCode.REQUIRED_CONNECTED_RECORD_NOT_FOUND); + + // connect found + await expect( + db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { connect: [{ id: vid1.id, duration: vid1.duration, rating: vid1.rating }] } }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ + ratedVideos: expect.arrayContaining([expect.objectContaining({ id: vid1.id })]), + }); + + // connectOrCreate + await expect( + db.user.update({ + where: { id: user.id }, + data: { + ratedVideos: { + connectOrCreate: [ + { + where: { id: vid2.id, duration: 333 }, + create: { + url: 'xyz', + duration: 333, + rating: 333, + }, + }, + ], + }, + }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ + ratedVideos: expect.arrayContaining([expect.objectContaining({ duration: 333 })]), + }); + + // disconnect not found + await expect( + db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { disconnect: [{ id: vid2.id }] } }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ + ratedVideos: expect.arrayContaining([expect.objectContaining({ id: vid1.id })]), + }); + + // disconnect found + await expect( + db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { disconnect: [{ id: vid1.id, duration: vid1.duration, rating: vid1.rating }] } }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ + ratedVideos: expect.arrayContaining([]), + }); + + // set + await expect( + db.user.update({ + where: { id: user.id }, + data: { + ratedVideos: { + set: [ + { id: vid1.id, viewCount: vid1.viewCount }, + { id: vid2.id, viewCount: vid2.viewCount }, + ], + }, + }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ + ratedVideos: expect.arrayContaining([ + expect.objectContaining({ id: vid1.id }), + expect.objectContaining({ id: vid2.id }), + ]), + }); + await expect( + db.user.update({ + where: { id: user.id }, + data: { ratedVideos: { set: [] } }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ + ratedVideos: expect.arrayContaining([]), + }); + await expect( + db.user.update({ + where: { id: user.id }, + data: { + ratedVideos: { + set: { id: vid1.id, viewCount: vid1.viewCount }, + }, + }, + include: { ratedVideos: true }, + }) + ).resolves.toMatchObject({ + ratedVideos: expect.arrayContaining([expect.objectContaining({ id: vid1.id })]), + }); + }); + + it('updateMany', async () => { + const { db, videoWithOwner: video, user } = await setup(); + const otherVideo = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 10000, duration: 10000, url: 'xyz', rating: 10000 }, + }); + + // update only the current level + await expect( + db.ratedVideo.updateMany({ + where: { rating: video.rating, viewCount: video.viewCount }, + data: { rating: 100 }, + }) + ).resolves.toMatchObject({ count: 1 }); + let read = await db.ratedVideo.findUnique({ where: { id: video.id } }); + expect(read).toMatchObject({ rating: 100 }); + + // update with concrete + await expect( + db.ratedVideo.updateMany({ + where: { id: video.id }, + data: { viewCount: 1, duration: 11, rating: 101 }, + }) + ).resolves.toMatchObject({ count: 1 }); + read = await db.ratedVideo.findUnique({ where: { id: video.id } }); + expect(read).toMatchObject({ viewCount: 1, duration: 11, rating: 101 }); + + // update with base + await db.video.updateMany({ + where: { viewCount: 1, duration: 11 }, + data: { viewCount: 2, duration: 12 }, + }); + read = await db.ratedVideo.findUnique({ where: { id: video.id } }); + expect(read).toMatchObject({ viewCount: 2, duration: 12 }); + + // update with base + await db.asset.updateMany({ + where: { viewCount: 2 }, + data: { viewCount: 3 }, + }); + read = await db.ratedVideo.findUnique({ where: { id: video.id } }); + expect(read.viewCount).toBe(3); + + // the other video is unchanged + await expect(await db.ratedVideo.findUnique({ where: { id: otherVideo.id } })).toMatchObject(otherVideo); + + // update with concrete no where + await expect( + db.ratedVideo.updateMany({ + data: { viewCount: 111, duration: 111, rating: 111 }, + }) + ).resolves.toMatchObject({ count: 2 }); + await expect(db.ratedVideo.findUnique({ where: { id: video.id } })).resolves.toMatchObject({ duration: 111 }); + await expect(db.ratedVideo.findUnique({ where: { id: otherVideo.id } })).resolves.toMatchObject({ + duration: 111, + }); + + // set discriminator + await expect(db.ratedVideo.updateMany({ data: { assetType: 'Image' } })).rejects.toThrow('is a discriminator'); + await expect(db.ratedVideo.updateMany({ data: { videoType: 'RatedVideo' } })).rejects.toThrow( + 'is a discriminator' + ); + }); + + it('upsert', async () => { + const { db, videoWithOwner: video, user } = await setup(); + + await expect( + db.asset.upsert({ + where: { id: video.id }, + create: { id: video.id, viewCount: 1 }, + update: { viewCount: 2 }, + }) + ).rejects.toThrow('is a delegate'); + + // update + await expect( + db.ratedVideo.upsert({ + where: { id: video.id }, + create: { + viewCount: 1, + duration: 300, + url: 'xyz', + rating: 100, + owner: { connect: { id: user.id } }, + }, + update: { duration: 200 }, + }) + ).resolves.toMatchObject({ + id: video.id, + duration: 200, + }); + + // create + const created = await db.ratedVideo.upsert({ + where: { id: video.id + 1 }, + create: { viewCount: 1, duration: 300, url: 'xyz', rating: 100, owner: { connect: { id: user.id } } }, + update: { duration: 200 }, + }); + expect(created.id).not.toEqual(video.id); + expect(created.duration).toBe(300); + }); + + it('delete', async () => { + let { db, user, video: ratedVideo } = await setup(); + + let deleted = await db.ratedVideo.delete({ + where: { id: ratedVideo.id }, + select: { rating: true, owner: true }, + }); + expect(deleted).toMatchObject({ rating: 100 }); + expect(deleted.owner).toMatchObject(user); + await expect(db.ratedVideo.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + await expect(db.asset.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + + // delete with base + ratedVideo = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'xyz', rating: 100 }, + }); + const video = await db.video.findUnique({ where: { id: ratedVideo.id } }); + deleted = await db.video.delete({ where: { id: ratedVideo.id }, include: { owner: true } }); + expect(deleted).toMatchObject(video); + expect(deleted.owner).toMatchObject(user); + await expect(db.ratedVideo.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + await expect(db.asset.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + + // delete with concrete + ratedVideo = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'xyz', rating: 100 }, + }); + let asset = await db.asset.findUnique({ where: { id: ratedVideo.id } }); + deleted = await db.video.delete({ where: { id: ratedVideo.id }, include: { owner: true } }); + expect(deleted).toMatchObject(asset); + expect(deleted.owner).toMatchObject(user); + await expect(db.ratedVideo.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + await expect(db.asset.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + + // delete with combined condition + ratedVideo = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'xyz', rating: 100 }, + }); + asset = await db.asset.findUnique({ where: { id: ratedVideo.id } }); + deleted = await db.video.delete({ where: { id: ratedVideo.id, viewCount: 1 } }); + expect(deleted).toMatchObject(asset); + await expect(db.ratedVideo.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + await expect(db.asset.findUnique({ where: { id: ratedVideo.id } })).resolves.toBeNull(); + }); + + it('deleteMany', async () => { + const { enhance } = await loadSchema(schema, { logPrismaQuery: true, enhancements: ['delegate'] }); + const db = enhance(); + + const user = await db.user.create({ data: { id: 1 } }); + + // no where + let video1 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'xyz', rating: 100 }, + }); + let video2 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'xyz', rating: 100 }, + }); + await expect(db.ratedVideo.deleteMany()).resolves.toMatchObject({ count: 2 }); + await expect(db.ratedVideo.findUnique({ where: { id: video1.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: video1.id } })).resolves.toBeNull(); + await expect(db.asset.findUnique({ where: { id: video1.id } })).resolves.toBeNull(); + await expect(db.ratedVideo.findUnique({ where: { id: video2.id } })).resolves.toBeNull(); + await expect(db.video.findUnique({ where: { id: video2.id } })).resolves.toBeNull(); + await expect(db.asset.findUnique({ where: { id: video2.id } })).resolves.toBeNull(); + await expect(db.ratedVideo.count()).resolves.toBe(0); + + // with base + video1 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'abc', rating: 100 }, + }); + video2 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 2, duration: 200, url: 'xyz', rating: 200 }, + }); + await expect(db.asset.deleteMany({ where: { viewCount: 1 } })).resolves.toMatchObject({ count: 1 }); + await expect(db.asset.count()).resolves.toBe(1); + await db.asset.deleteMany(); + + // where current level + video1 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'abc', rating: 100 }, + }); + video2 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 2, duration: 200, url: 'xyz', rating: 200 }, + }); + await expect(db.ratedVideo.deleteMany({ where: { rating: 100 } })).resolves.toMatchObject({ count: 1 }); + await expect(db.ratedVideo.count()).resolves.toBe(1); + await db.ratedVideo.deleteMany(); + + // where mixed with base level + video1 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'abc', rating: 100 }, + }); + video2 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 2, duration: 200, url: 'xyz', rating: 200 }, + }); + await expect(db.ratedVideo.deleteMany({ where: { viewCount: 1, duration: 100 } })).resolves.toMatchObject({ + count: 1, + }); + await expect(db.ratedVideo.count()).resolves.toBe(1); + await db.ratedVideo.deleteMany(); + + // delete not found + video1 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 1, duration: 100, url: 'abc', rating: 100 }, + }); + video2 = await db.ratedVideo.create({ + data: { owner: { connect: { id: user.id } }, viewCount: 2, duration: 200, url: 'xyz', rating: 200 }, + }); + await expect(db.ratedVideo.deleteMany({ where: { viewCount: 2, duration: 100 } })).resolves.toMatchObject({ + count: 0, + }); + await expect(db.ratedVideo.count()).resolves.toBe(2); + }); + + it('aggregate', async () => { + const { db } = await setup(); + + const aggregate = await db.ratedVideo.aggregate({ + _count: true, + _sum: { rating: true }, + where: { viewCount: { gt: 0 }, rating: { gt: 10 } }, + orderBy: { + duration: 'desc', + }, + }); + expect(aggregate).toMatchObject({ _count: 1, _sum: { rating: 100 } }); + + expect(() => db.ratedVideo.aggregate({ _count: true, _sum: { rating: true, viewCount: true } })).toThrow( + 'aggregate with fields from base type is not supported yet' + ); + }); + + it('count', async () => { + const { db } = await setup(); + + let count = await db.ratedVideo.count(); + expect(count).toBe(1); + + count = await db.ratedVideo.count({ + select: { _all: true, rating: true }, + where: { viewCount: { gt: 0 }, rating: { gt: 10 } }, + }); + expect(count).toMatchObject({ _all: 1, rating: 1 }); + + expect(() => db.ratedVideo.count({ select: { rating: true, viewCount: true } })).toThrow( + 'count with fields from base type is not supported yet' + ); + }); + + it('groupBy', async () => { + const { db, video } = await setup(); + + let group = await db.ratedVideo.groupBy({ by: ['rating'] }); + expect(group).toHaveLength(1); + expect(group[0]).toMatchObject({ rating: video.rating }); + + group = await db.ratedVideo.groupBy({ + by: ['id', 'rating'], + where: { viewCount: { gt: 0 }, rating: { gt: 10 } }, + }); + expect(group).toHaveLength(1); + expect(group[0]).toMatchObject({ id: video.id, rating: video.rating }); + + group = await db.ratedVideo.groupBy({ + by: ['id'], + _sum: { rating: true }, + }); + expect(group).toHaveLength(1); + expect(group[0]).toMatchObject({ id: video.id, _sum: { rating: video.rating } }); + + group = await db.ratedVideo.groupBy({ + by: ['id'], + _sum: { rating: true }, + having: { rating: { _sum: { gt: video.rating } } }, + }); + expect(group).toHaveLength(0); + + expect(() => db.ratedVideo.groupBy({ by: 'viewCount' })).toThrow( + 'groupBy with fields from base type is not supported yet' + ); + expect(() => db.ratedVideo.groupBy({ having: { rating: { gt: 0 }, viewCount: { gt: 0 } } })).toThrow( + 'groupBy with fields from base type is not supported yet' + ); + }); +}); diff --git a/tests/integration/tests/schema/petstore.zmodel b/tests/integration/tests/schema/petstore.zmodel index 77ec1e643..42a279550 100644 --- a/tests/integration/tests/schema/petstore.zmodel +++ b/tests/integration/tests/schema/petstore.zmodel @@ -5,7 +5,6 @@ datasource db { generator js { provider = 'prisma-client-js' - previewFeatures = ['clientExtensions'] } plugin zod { diff --git a/tests/integration/tests/schema/todo.zmodel b/tests/integration/tests/schema/todo.zmodel index 733391bd1..c3a84707e 100644 --- a/tests/integration/tests/schema/todo.zmodel +++ b/tests/integration/tests/schema/todo.zmodel @@ -9,7 +9,6 @@ datasource db { generator js { provider = 'prisma-client-js' - previewFeatures = ['clientExtensions'] } plugin zod {