diff --git a/packages/ide/jetbrains/package.json b/packages/ide/jetbrains/package.json index 274e88c2a..4e7fc26df 100644 --- a/packages/ide/jetbrains/package.json +++ b/packages/ide/jetbrains/package.json @@ -6,7 +6,7 @@ "homepage": "https://zenstack.dev", "private": true, "scripts": { - "build": "./gradlew buildPlugin" + "build": "./gradlew buildPlugin" }, "author": "ZenStack Team", "license": "MIT", diff --git a/packages/runtime/src/cross/model-meta.ts b/packages/runtime/src/cross/model-meta.ts index a38f7986d..401caeaf2 100644 --- a/packages/runtime/src/cross/model-meta.ts +++ b/packages/runtime/src/cross/model-meta.ts @@ -8,6 +8,11 @@ export type RuntimeAttribute = { args: Array<{ name?: string; value: unknown }>; }; +/** + * Function for computing default value for a field + */ +export type FieldDefaultValueProvider = (userContext: unknown) => unknown; + /** * Runtime information of a data model field */ @@ -67,6 +72,11 @@ export type FieldInfo = { */ foreignKeyMapping?: Record; + /** + * A function that provides a default value for the field + */ + defaultValueProvider?: FieldDefaultValueProvider; + /** * If the field is an auto-increment field */ diff --git a/packages/runtime/src/cross/nested-write-visitor.ts b/packages/runtime/src/cross/nested-write-visitor.ts index 7d67f6d9b..477117dbd 100644 --- a/packages/runtime/src/cross/nested-write-visitor.ts +++ b/packages/runtime/src/cross/nested-write-visitor.ts @@ -34,7 +34,7 @@ export type NestedWriteVisitorContext = { * to let the visitor traverse it instead of its original children. */ export type NestedWriterVisitorCallback = { - create?: (model: string, args: any[], context: NestedWriteVisitorContext) => MaybePromise; + create?: (model: string, data: any, context: NestedWriteVisitorContext) => MaybePromise; createMany?: ( model: string, diff --git a/packages/runtime/src/enhancements/create-enhancement.ts b/packages/runtime/src/enhancements/create-enhancement.ts index a82640905..e3204cd52 100644 --- a/packages/runtime/src/enhancements/create-enhancement.ts +++ b/packages/runtime/src/enhancements/create-enhancement.ts @@ -7,6 +7,7 @@ import { withPassword } from './password'; import { withPolicy } from './policy'; import type { ErrorTransformer } from './proxy'; import type { PolicyDef, ZodSchemas } from './types'; +import { withDefaultAuth } from './default-auth'; /** * Kinds of enhancements to `PrismaClient` @@ -15,6 +16,7 @@ export enum EnhancementKind { Password = 'password', Omit = 'omit', Policy = 'policy', + DefaultAuth = 'defaultAuth', } /** @@ -92,6 +94,7 @@ export type EnhancementContext = { let hasPassword: boolean | undefined = undefined; let hasOmit: boolean | undefined = undefined; +let hasDefaultAuth: boolean | undefined = undefined; /** * Gets a Prisma client enhanced with all enhancement behaviors, including access @@ -120,13 +123,24 @@ export function createEnhancement( let result = prisma; - if (hasPassword === undefined || hasOmit === undefined) { + if ( + process.env.ZENSTACK_TEST === '1' || // avoid caching in tests + hasPassword === undefined || + hasOmit === undefined || + hasDefaultAuth === undefined + ) { const allFields = Object.values(options.modelMeta.fields).flatMap((modelInfo) => Object.values(modelInfo)); hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password')); hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit')); + hasDefaultAuth = allFields.some((field) => field.defaultValueProvider); } - const kinds = options.kinds ?? [EnhancementKind.Password, EnhancementKind.Omit, EnhancementKind.Policy]; + const kinds = options.kinds ?? [ + EnhancementKind.Password, + EnhancementKind.Omit, + EnhancementKind.Policy, + EnhancementKind.DefaultAuth, + ]; if (hasPassword && kinds.includes(EnhancementKind.Password)) { // @password proxy @@ -138,6 +152,11 @@ export function createEnhancement( result = withOmit(result, options); } + if (hasDefaultAuth && kinds.includes(EnhancementKind.DefaultAuth)) { + // @default(auth()) proxy + result = withDefaultAuth(result, options, context); + } + // policy proxy if (kinds.includes(EnhancementKind.Policy)) { result = withPolicy(result, options, context); diff --git a/packages/runtime/src/enhancements/default-auth.ts b/packages/runtime/src/enhancements/default-auth.ts new file mode 100644 index 000000000..48af0ed73 --- /dev/null +++ b/packages/runtime/src/enhancements/default-auth.ts @@ -0,0 +1,102 @@ +/* eslint-disable @typescript-eslint/no-unused-vars */ +/* eslint-disable @typescript-eslint/no-explicit-any */ + +import deepcopy from 'deepcopy'; +import { FieldInfo, NestedWriteVisitor, PrismaWriteActionType, enumerate, getFields } from '../cross'; +import { DbClientContract } from '../types'; +import { EnhancementContext, EnhancementOptions } from './create-enhancement'; +import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy'; + +/** + * Gets an enhanced Prisma client that supports `@default(auth())` attribute. + * + * @private + */ +export function withDefaultAuth( + prisma: DbClient, + options: EnhancementOptions, + context?: EnhancementContext +): DbClient { + return makeProxy( + prisma, + options.modelMeta, + (_prisma, model) => new DefaultAuthHandler(_prisma as DbClientContract, model, options, context), + 'defaultAuth' + ); +} + +class DefaultAuthHandler extends DefaultPrismaProxyHandler { + private readonly db: DbClientContract; + private readonly userContext: any; + + constructor( + prisma: DbClientContract, + model: string, + private readonly options: EnhancementOptions, + private readonly context?: EnhancementContext + ) { + super(prisma, model); + this.db = prisma; + + if (!this.context?.user) { + throw new Error(`Using \`auth()\` in \`@default\` requires a user context`); + } + + this.userContext = this.context.user; + } + + // base override + protected async preprocessArgs(action: PrismaProxyActions, args: any) { + const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert']; + if (actionsOfInterest.includes(action)) { + const newArgs = await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args); + return newArgs; + } + return args; + } + + private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) { + const newArgs = deepcopy(args); + + const processCreatePayload = (model: string, data: any) => { + const fields = getFields(this.options.modelMeta, model); + for (const fieldInfo of Object.values(fields)) { + if (fieldInfo.name in data) { + // create payload already sets field value + continue; + } + + if (!fieldInfo.defaultValueProvider) { + // field doesn't have a runtime default value provider + continue; + } + + const authDefaultValue = this.getDefaultValueFromAuth(fieldInfo); + if (authDefaultValue !== undefined) { + // set field value extracted from `auth()` + data[fieldInfo.name] = authDefaultValue; + } + } + }; + + // visit create payload and set default value to fields using `auth()` in `@default()` + const visitor = new NestedWriteVisitor(this.options.modelMeta, { + create: (model, data) => { + processCreatePayload(model, data); + }, + + createMany: (model, args) => { + for (const item of enumerate(args.data)) { + processCreatePayload(model, item); + } + }, + }); + + await visitor.visit(model, action, newArgs); + return newArgs; + } + + private getDefaultValueFromAuth(fieldInfo: FieldInfo) { + return fieldInfo.defaultValueProvider?.(this.userContext); + } +} diff --git a/packages/runtime/src/enhancements/utils.ts b/packages/runtime/src/enhancements/utils.ts index ba2f9a2d8..2879a3119 100644 --- a/packages/runtime/src/enhancements/utils.ts +++ b/packages/runtime/src/enhancements/utils.ts @@ -22,3 +22,18 @@ export function prismaClientKnownRequestError(prisma: DbClientContract, prismaMo export function prismaClientUnknownRequestError(prismaModule: any, ...args: unknown[]): Error { throw new prismaModule.PrismaClientUnknownRequestError(...args); } + +export function deepGet(object: object, path: string | string[] | undefined, defaultValue: unknown): unknown { + if (path === undefined || path === '') { + return defaultValue; + } + const keys = Array.isArray(path) ? path : path.split('.'); + for (const key of keys) { + if (object && typeof object === 'object' && key in object) { + object = object[key as keyof typeof object]; + } else { + return defaultValue; + } + } + return object !== undefined ? object : defaultValue; +} diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts index 7644521b8..cfc8a39af 100644 --- a/packages/schema/src/language-server/validator/expression-validator.ts +++ b/packages/schema/src/language-server/validator/expression-validator.ts @@ -3,16 +3,16 @@ import { Expression, ExpressionType, isDataModel, + isDataModelField, isEnum, + isLiteralExpr, isMemberAccessExpr, isNullExpr, isThisExpr, - isDataModelField, - isLiteralExpr, } from '@zenstackhq/language/ast'; -import { isDataModelFieldReference, isEnumFieldReference } from '@zenstackhq/sdk'; +import { isAuthInvocation, isDataModelFieldReference, isEnumFieldReference } from '@zenstackhq/sdk'; import { ValidationAcceptor } from 'langium'; -import { getContainingDataModel, isAuthInvocation, isCollectionPredicate } from '../../utils/ast-utils'; +import { getContainingDataModel, isCollectionPredicate } from '../../utils/ast-utils'; import { AstValidator } from '../types'; import { typeAssignable } from './utils'; @@ -132,18 +132,24 @@ export default class ExpressionValidator implements AstValidator { // - foo.user.id == userId // except: // - future().userId == userId - if(isMemberAccessExpr(expr.left) && isDataModelField(expr.left.member.ref) && expr.left.member.ref.$container != getContainingDataModel(expr) - || isMemberAccessExpr(expr.right) && isDataModelField(expr.right.member.ref) && expr.right.member.ref.$container != getContainingDataModel(expr)) - { + if ( + (isMemberAccessExpr(expr.left) && + isDataModelField(expr.left.member.ref) && + expr.left.member.ref.$container != getContainingDataModel(expr)) || + (isMemberAccessExpr(expr.right) && + isDataModelField(expr.right.member.ref) && + expr.right.member.ref.$container != getContainingDataModel(expr)) + ) { // foo.user.id == auth().id // foo.user.id == "123" // foo.user.id == null // foo.user.id == EnumValue - if(!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) - { - accept('error', 'comparison between fields of different models are not supported', { node: expr }); - break; - } + if (!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) { + accept('error', 'comparison between fields of different models are not supported', { + node: expr, + }); + break; + } } if ( @@ -205,14 +211,13 @@ export default class ExpressionValidator implements AstValidator { } } - private isNotModelFieldExpr(expr: Expression) { - return isLiteralExpr(expr) || isEnumFieldReference(expr) || isNullExpr(expr) || this.isAuthOrAuthMemberAccess(expr) + return ( + isLiteralExpr(expr) || isEnumFieldReference(expr) || isNullExpr(expr) || this.isAuthOrAuthMemberAccess(expr) + ); } private isAuthOrAuthMemberAccess(expr: Expression) { return isAuthInvocation(expr) || (isMemberAccessExpr(expr) && isAuthInvocation(expr.operand)); } - } - diff --git a/packages/schema/src/language-server/validator/function-invocation-validator.ts b/packages/schema/src/language-server/validator/function-invocation-validator.ts index 3bc364bd2..50b974a53 100644 --- a/packages/schema/src/language-server/validator/function-invocation-validator.ts +++ b/packages/schema/src/language-server/validator/function-invocation-validator.ts @@ -11,10 +11,15 @@ import { isDataModelFieldAttribute, isLiteralExpr, } from '@zenstackhq/language/ast'; -import { ExpressionContext, getFunctionExpressionContext, isEnumFieldReference, isFromStdlib } from '@zenstackhq/sdk'; +import { + ExpressionContext, + getDataModelFieldReference, + getFunctionExpressionContext, + isEnumFieldReference, + isFromStdlib, +} from '@zenstackhq/sdk'; import { AstNode, ValidationAcceptor } from 'langium'; import { P, match } from 'ts-pattern'; -import { getDataModelFieldReference } from '../../utils/ast-utils'; import { AstValidator } from '../types'; import { typeAssignable } from './utils'; diff --git a/packages/schema/src/language-server/zmodel-linker.ts b/packages/schema/src/language-server/zmodel-linker.ts index ef97cf4b6..8c8fb2c98 100644 --- a/packages/schema/src/language-server/zmodel-linker.ts +++ b/packages/schema/src/language-server/zmodel-linker.ts @@ -35,7 +35,7 @@ import { isReferenceExpr, isStringLiteral, } from '@zenstackhq/language/ast'; -import { getContainingModel, hasAttribute, isFromStdlib } from '@zenstackhq/sdk'; +import { getContainingModel, hasAttribute, isAuthInvocation, isFutureExpr } from '@zenstackhq/sdk'; import { AstNode, AstNodeDescription, @@ -52,12 +52,7 @@ import { } from 'langium'; import { match } from 'ts-pattern'; import { CancellationToken } from 'vscode-jsonrpc'; -import { - getAllDeclarationsFromImports, - getContainingDataModel, - isAuthInvocation, - isCollectionPredicate, -} from '../utils/ast-utils'; +import { getAllDeclarationsFromImports, getContainingDataModel, isCollectionPredicate } from '../utils/ast-utils'; import { mapBuiltinTypeToExpressionType } from './validator/utils'; interface DefaultReference extends Reference { @@ -329,7 +324,7 @@ export class ZModelLinker extends DefaultLinker { if (node.function.ref) { // eslint-disable-next-line @typescript-eslint/ban-types const funcDecl = node.function.ref as FunctionDecl; - if (funcDecl.name === 'auth' && isFromStdlib(funcDecl)) { + if (isAuthInvocation(node)) { // auth() function is resolved to User model in the current document const model = getContainingModel(node); @@ -346,7 +341,7 @@ export class ZModelLinker extends DefaultLinker { node.$resolvedType = { decl: authModel, nullable: true }; } } - } else if (funcDecl.name === 'future' && isFromStdlib(funcDecl)) { + } else if (isFutureExpr(node)) { // future() function is resolved to current model node.$resolvedType = { decl: getContainingDataModel(node) }; } else { diff --git a/packages/schema/src/plugins/enhancer/policy/expression-writer.ts b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts index 0cc80c7ea..e38a34c29 100644 --- a/packages/schema/src/plugins/enhancer/policy/expression-writer.ts +++ b/packages/schema/src/plugins/enhancer/policy/expression-writer.ts @@ -19,19 +19,18 @@ import { import { ExpressionContext, getFunctionExpressionContext, + getIdFields, getLiteral, + isAuthInvocation, isDataModelFieldReference, isFutureExpr, PluginError, + TypeScriptExpressionTransformer, + TypeScriptExpressionTransformerError, } from '@zenstackhq/sdk'; import { lowerCaseFirst } from 'lower-case-first'; import { CodeBlockWriter } from 'ts-morph'; import { name } from '..'; -import { getIdFields, isAuthInvocation } from '../../../utils/ast-utils'; -import { - TypeScriptExpressionTransformer, - TypeScriptExpressionTransformerError, -} from '../../../utils/typescript-expression-transformer'; type ComparisonOperator = '==' | '!=' | '>' | '>=' | '<' | '<='; diff --git a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts index e5017383d..149858cd6 100644 --- a/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts @@ -33,14 +33,18 @@ import { PluginError, PluginOptions, RUNTIME_PACKAGE, + TypeScriptExpressionTransformer, + TypeScriptExpressionTransformerError, analyzePolicies, getAttributeArg, getAuthModel, getDataModels, + getIdFields, getLiteral, getPrismaClientImportSpec, hasAttribute, hasValidationAttributes, + isAuthInvocation, isEnumFieldReference, isForeignKeyField, isFromStdlib, @@ -52,11 +56,7 @@ import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; import { FunctionDeclaration, Project, SourceFile, VariableDeclarationKind, WriterFunction } from 'ts-morph'; import { name } from '..'; -import { getIdFields, isAuthInvocation, isCollectionPredicate } from '../../../utils/ast-utils'; -import { - TypeScriptExpressionTransformer, - TypeScriptExpressionTransformerError, -} from '../../../utils/typescript-expression-transformer'; +import { isCollectionPredicate } from '../../../utils/ast-utils'; import { ALL_OPERATION_KINDS } from '../../plugin-utils'; import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; diff --git a/packages/schema/src/plugins/prisma/prisma-builder.ts b/packages/schema/src/plugins/prisma/prisma-builder.ts index 64777b62e..68336baeb 100644 --- a/packages/schema/src/plugins/prisma/prisma-builder.ts +++ b/packages/schema/src/plugins/prisma/prisma-builder.ts @@ -310,7 +310,6 @@ export class FunctionCallArg { return this.name ? `${this.name}: ${this.value}` : this.value; } } - export class Enum extends ContainerDeclaration { public fields: EnumField[] = []; diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index feee0f3d1..0f25ab1b8 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -33,6 +33,7 @@ import { getDMMF, getLiteral, getPrismaVersion, + isDefaultAuthField, PluginError, PluginOptions, resolved, @@ -311,9 +312,7 @@ export default class PrismaSchemaGenerator { const type = new ModelFieldType(fieldType, field.type.array, field.type.optional); - const attributes = field.attributes - .filter((attr) => this.isPrismaAttribute(attr)) - .map((attr) => this.makeFieldAttribute(attr)); + const attributes = this.getAttributesToGenerate(field); const nonPrismaAttributes = field.attributes.filter((attr) => attr.decl.ref && !this.isPrismaAttribute(attr)); @@ -325,6 +324,15 @@ export default class PrismaSchemaGenerator { field.comments.forEach((c) => result.addComment(c)); } + private getAttributesToGenerate(field: DataModelField) { + if (isDefaultAuthField(field)) { + return []; + } + return field.attributes + .filter((attr) => this.isPrismaAttribute(attr)) + .map((attr) => this.makeFieldAttribute(attr)); + } + private makeFieldAttribute(attr: DataModelFieldAttribute) { const attrName = resolved(attr.decl).name; if (attrName === FIELD_PASSTHROUGH_ATTR) { diff --git a/packages/schema/src/plugins/zod/utils/schema-gen.ts b/packages/schema/src/plugins/zod/utils/schema-gen.ts index 802127c58..02607d4c7 100644 --- a/packages/schema/src/plugins/zod/utils/schema-gen.ts +++ b/packages/schema/src/plugins/zod/utils/schema-gen.ts @@ -1,6 +1,8 @@ import { ExpressionContext, PluginError, + TypeScriptExpressionTransformer, + TypeScriptExpressionTransformerError, getAttributeArg, getAttributeArgLiteral, getLiteral, @@ -18,10 +20,6 @@ import { } from '@zenstackhq/sdk/ast'; import { upperCaseFirst } from 'upper-case-first'; import { name } from '..'; -import { - TypeScriptExpressionTransformer, - TypeScriptExpressionTransformerError, -} from '../../../utils/typescript-expression-transformer'; export function makeFieldSchema(field: DataModelField, respectDefault = false) { if (isDataModel(field.type.reference?.ref)) { diff --git a/packages/schema/src/res/stdlib.zmodel b/packages/schema/src/res/stdlib.zmodel index 1a9446d7b..f755bb3df 100644 --- a/packages/schema/src/res/stdlib.zmodel +++ b/packages/schema/src/res/stdlib.zmodel @@ -73,7 +73,7 @@ function env(name: String): String { * Gets the current login user. */ function auth(): Any { -} @@@expressionContext([AccessPolicy]) +} @@@expressionContext([DefaultValue, AccessPolicy]) /** * Gets current date-time (as DateTime type). @@ -204,7 +204,7 @@ attribute @id(map: String?, length: Int?, sort: String?, clustered: Boolean?) @@ /** * Defines a default value for a field. - * @param value: An expression (e.g. 5, true, now()). + * @param value: An expression (e.g. 5, true, now(), auth()). */ attribute @default(_ value: ContextType, map: String?) @@@prisma diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index 661f14b26..80543d6a2 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -1,21 +1,13 @@ import { BinaryExpr, DataModel, - DataModelField, Expression, - isArrayExpr, isBinaryExpr, isDataModel, - isDataModelField, - isInvocationExpr, - isMemberAccessExpr, isModel, - isReferenceExpr, Model, ModelImport, - ReferenceExpr, } from '@zenstackhq/language/ast'; -import { isFromStdlib } from '@zenstackhq/sdk'; import { AstNode, getDocument, LangiumDocuments, Mutable } from 'langium'; import { URI, Utils } from 'vscode-uri'; @@ -56,43 +48,6 @@ function updateContainer(nodes: T[], container: AstNode): Mut }); } -export function getIdFields(dataModel: DataModel) { - const fieldLevelId = dataModel.$resolvedFields.find((f) => - f.attributes.some((attr) => attr.decl.$refText === '@id') - ); - if (fieldLevelId) { - return [fieldLevelId]; - } else { - // get model level @@id attribute - const modelIdAttr = dataModel.attributes.find((attr) => attr.decl?.ref?.name === '@@id'); - if (modelIdAttr) { - // get fields referenced in the attribute: @@id([field1, field2]]) - if (!isArrayExpr(modelIdAttr.args[0].value)) { - return []; - } - const argValue = modelIdAttr.args[0].value; - return argValue.items - .filter((expr): expr is ReferenceExpr => isReferenceExpr(expr) && !!getDataModelFieldReference(expr)) - .map((expr) => expr.target.ref as DataModelField); - } - } - return []; -} - -export function isAuthInvocation(node: AstNode) { - return isInvocationExpr(node) && node.function.ref?.name === 'auth' && isFromStdlib(node.function.ref); -} - -export function getDataModelFieldReference(expr: Expression): DataModelField | undefined { - if (isReferenceExpr(expr) && isDataModelField(expr.target.ref)) { - return expr.target.ref; - } else if (isMemberAccessExpr(expr) && isDataModelField(expr.member.ref)) { - return expr.member.ref; - } else { - return undefined; - } -} - export function resolveImportUri(imp: ModelImport): URI | undefined { if (imp.path === undefined || imp.path.length === 0) { return undefined; @@ -157,7 +112,6 @@ export function isCollectionPredicate(node: AstNode): node is BinaryExpr { return isBinaryExpr(node) && ['?', '!', '^'].includes(node.operator); } - export function getContainingDataModel(node: Expression): DataModel | undefined { let curr: AstNode | undefined = node.$container; while (curr) { @@ -167,4 +121,4 @@ export function getContainingDataModel(node: Expression): DataModel | undefined curr = curr.$container; } return undefined; -} \ No newline at end of file +} diff --git a/packages/schema/tests/generator/prisma-generator.test.ts b/packages/schema/tests/generator/prisma-generator.test.ts index 8d295d143..d2f425e53 100644 --- a/packages/schema/tests/generator/prisma-generator.test.ts +++ b/packages/schema/tests/generator/prisma-generator.test.ts @@ -123,6 +123,7 @@ describe('Prisma generator test', () => { id String @id @default(nanoid(6)) x String @default(nanoid()) y String @default(dbgenerated("gen_random_uuid()")) + z String @default(auth().id) } `); @@ -142,6 +143,7 @@ describe('Prisma generator test', () => { expect(content).toContain('@default(nanoid(6))'); expect(content).toContain('@default(nanoid())'); expect(content).toContain('@default(dbgenerated("gen_random_uuid()"))'); + expect(content).not.toContain('@default(auth().id)'); }); it('triple slash comments', async () => { diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index 8b7886334..cb2f788d4 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -1009,6 +1009,35 @@ describe('Attribute tests', () => { }); it('auth function check', async () => { + await loadModel(` + ${prelude} + + model User { + id String @id + name String + } + model B { + id String @id + userId String @default(auth().id) + userName String @default(auth().name) + } + `); + + // expect( + // await loadModelWithError(` + // ${prelude} + + // model User { + // id String @id + // name String + // } + // model B { + // id String @id + // userData String @default(auth()) + // } + // `) + // ).toContain("Value is not assignable to parameter"); + expect( await loadModelWithError(` ${prelude} @@ -1124,14 +1153,14 @@ describe('Attribute tests', () => { }); it('incorrect function expression context', async () => { - expect( - await loadModelWithError(` - ${prelude} - model M { - id String @id @default(auth()) - } - `) - ).toContain('function "auth" is not allowed in the current context: DefaultValue'); + // expect( + // await loadModelWithError(` + // ${prelude} + // model M { + // id String @id @default(auth()) + // } + // `) + // ).toContain('function "auth" is not allowed in the current context: DefaultValue'); expect( await loadModelWithError(` diff --git a/packages/sdk/package.json b/packages/sdk/package.json index beddaad70..ac8bcaf1d 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -23,10 +23,12 @@ "@prisma/internals-v5": "npm:@prisma/internals@^5.0.0", "@zenstackhq/language": "workspace:*", "@zenstackhq/runtime": "workspace:*", + "langium": "1.2.0", "lower-case-first": "^2.0.2", "prettier": "^2.8.3 || 3.x", "semver": "^7.5.2", "ts-morph": "^16.0.0", + "ts-pattern": "^4.3.0", "upper-case-first": "^2.0.2" }, "devDependencies": { diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index 64060390e..5013267e8 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -4,6 +4,7 @@ export { generate as generateModelMeta } from './model-meta-generator'; export * from './policy'; export * from './prisma'; export * from './types'; +export * from './typescript-expression-transformer'; export * from './utils'; export * from './validation'; export * from './zmodel-code-generator'; diff --git a/packages/sdk/src/model-meta-generator.ts b/packages/sdk/src/model-meta-generator.ts index 2692706c1..9beda653a 100644 --- a/packages/sdk/src/model-meta-generator.ts +++ b/packages/sdk/src/model-meta-generator.ts @@ -12,9 +12,11 @@ import { ReferenceExpr, } from '@zenstackhq/language/ast'; import type { RuntimeAttribute } from '@zenstackhq/runtime'; +import { streamAst } from 'langium'; import { lowerCaseFirst } from 'lower-case-first'; -import { CodeBlockWriter, Project, VariableDeclarationKind } from 'ts-morph'; +import { CodeBlockWriter, Project, SourceFile, VariableDeclarationKind } from 'ts-morph'; import { + ExpressionContext, getAttribute, getAttributeArg, getAttributeArgs, @@ -22,10 +24,12 @@ import { getDataModels, getLiteral, hasAttribute, + isAuthInvocation, isEnumFieldReference, isForeignKeyField, isIdField, resolved, + TypeScriptExpressionTransformer, } from '.'; export type ModelMetaGeneratorOptions = { @@ -38,13 +42,20 @@ export async function generate(project: Project, models: DataModel[], options: M sf.addStatements('/* eslint-disable */'); sf.addVariableStatement({ declarationKind: VariableDeclarationKind.Const, - declarations: [{ name: 'metadata', initializer: (writer) => generateModelMetadata(models, writer, options) }], + declarations: [ + { name: 'metadata', initializer: (writer) => generateModelMetadata(models, sf, writer, options) }, + ], }); sf.addStatements('export default metadata;'); return sf; } -function generateModelMetadata(dataModels: DataModel[], writer: CodeBlockWriter, options: ModelMetaGeneratorOptions) { +function generateModelMetadata( + dataModels: DataModel[], + sourceFile: SourceFile, + writer: CodeBlockWriter, + options: ModelMetaGeneratorOptions +) { writer.block(() => { writer.write('fields:'); writer.block(() => { @@ -120,6 +131,12 @@ function generateModelMetadata(dataModels: DataModel[], writer: CodeBlockWriter, foreignKeyMapping: ${JSON.stringify(fkMapping)},`); } + const defaultValueProvider = generateDefaultValueProvider(f, sourceFile); + if (defaultValueProvider) { + writer.write(` + defaultValueProvider: ${defaultValueProvider},`); + } + if (isAutoIncrement(f)) { writer.write(` isAutoIncrement: true,`); @@ -334,6 +351,39 @@ function getDeleteCascades(model: DataModel): string[] { .map((m) => m.name); } +function generateDefaultValueProvider(field: DataModelField, sourceFile: SourceFile) { + const defaultAttr = getAttribute(field, '@default'); + if (!defaultAttr) { + return undefined; + } + + const expr = defaultAttr.args[0]?.value; + if (!expr) { + return undefined; + } + + // find `auth()` in default value expression + const hasAuth = streamAst(expr).some(isAuthInvocation); + if (!hasAuth) { + return undefined; + } + + // generates a provider function like: + // function $default$Model$field(user: any) { ... } + const func = sourceFile.addFunction({ + name: `$default$${field.$container.name}$${field.name}`, + parameters: [{ name: 'user', type: 'any' }], + returnType: 'unknown', + statements: (writer) => { + const tsWriter = new TypeScriptExpressionTransformer({ context: ExpressionContext.DefaultValue }); + const code = tsWriter.transform(expr, false); + writer.write(`return ${code};`); + }, + }); + + return func.getName(); +} + function isAutoIncrement(field: DataModelField) { const defaultAttr = getAttribute(field, '@default'); if (!defaultAttr) { diff --git a/packages/schema/src/utils/typescript-expression-transformer.ts b/packages/sdk/src/typescript-expression-transformer.ts similarity index 98% rename from packages/schema/src/utils/typescript-expression-transformer.ts rename to packages/sdk/src/typescript-expression-transformer.ts index cd868d76c..20585118c 100644 --- a/packages/schema/src/utils/typescript-expression-transformer.ts +++ b/packages/sdk/src/typescript-expression-transformer.ts @@ -17,9 +17,9 @@ import { ThisExpr, UnaryExpr, } from '@zenstackhq/language/ast'; -import { ExpressionContext, getLiteral, isFromStdlib, isFutureExpr } from '@zenstackhq/sdk'; import { match, P } from 'ts-pattern'; -import { getIdFields } from './ast-utils'; +import { ExpressionContext } from './constants'; +import { getIdFields, getLiteral, isFromStdlib, isFutureExpr } from './utils'; export class TypeScriptExpressionTransformerError extends Error { constructor(message: string) { diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index afd043565..2f046b692 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -22,6 +22,7 @@ import { isGeneratorDecl, isInvocationExpr, isLiteralExpr, + isMemberAccessExpr, isModel, isObjectExpr, isReferenceExpr, @@ -280,6 +281,13 @@ export function isForeignKeyField(field: DataModelField) { }); } +export function isDefaultAuthField(field: DataModelField) { + return ( + hasAttribute(field, '@default') && + !!field.attributes.find((attr) => attr.args?.[0]?.value.$cstNode?.text.startsWith('auth()')) + ); +} + export function resolvePath(_path: string, options: Pick) { if (path.isAbsolute(_path)) { return _path; @@ -334,7 +342,11 @@ export function getFunctionExpressionContext(funcDecl: FunctionDecl) { } export function isFutureExpr(node: AstNode) { - return !!(isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)); + return isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref); +} + +export function isAuthInvocation(node: AstNode) { + return isInvocationExpr(node) && node.function.ref?.name === 'auth' && isFromStdlib(node.function.ref); } export function isFromStdlib(node: AstNode) { @@ -373,3 +385,36 @@ export function getAuthModel(dataModels: DataModel[]) { } return authModel; } + +export function getIdFields(dataModel: DataModel) { + const fieldLevelId = dataModel.$resolvedFields.find((f) => + f.attributes.some((attr) => attr.decl.$refText === '@id') + ); + if (fieldLevelId) { + return [fieldLevelId]; + } else { + // get model level @@id attribute + const modelIdAttr = dataModel.attributes.find((attr) => attr.decl?.ref?.name === '@@id'); + if (modelIdAttr) { + // get fields referenced in the attribute: @@id([field1, field2]]) + if (!isArrayExpr(modelIdAttr.args[0].value)) { + return []; + } + const argValue = modelIdAttr.args[0].value; + return argValue.items + .filter((expr): expr is ReferenceExpr => isReferenceExpr(expr) && !!getDataModelFieldReference(expr)) + .map((expr) => expr.target.ref as DataModelField); + } + } + return []; +} + +export function getDataModelFieldReference(expr: Expression): DataModelField | undefined { + if (isReferenceExpr(expr) && isDataModelField(expr.target.ref)) { + return expr.target.ref; + } else if (isMemberAccessExpr(expr) && isDataModelField(expr.member.ref)) { + return expr.member.ref; + } else { + return undefined; + } +} diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 8af4092a1..1bbcc402d 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -620,6 +620,9 @@ importers: '@zenstackhq/runtime': specifier: workspace:* version: link:../runtime/dist + langium: + specifier: 1.2.0 + version: 1.2.0 lower-case-first: specifier: ^2.0.2 version: 2.0.2 @@ -632,6 +635,9 @@ importers: ts-morph: specifier: ^16.0.0 version: 16.0.0 + ts-pattern: + specifier: ^4.3.0 + version: 4.3.0 upper-case-first: specifier: ^2.0.2 version: 2.0.2 diff --git a/tests/integration/tests/enhancements/with-policy/auth.test.ts b/tests/integration/tests/enhancements/with-policy/auth.test.ts index 942d2d579..f5b4e2f4f 100644 --- a/tests/integration/tests/enhancements/with-policy/auth.test.ts +++ b/tests/integration/tests/enhancements/with-policy/auth.test.ts @@ -363,4 +363,146 @@ describe('With Policy: auth() test', () => { enhance({ id: '1', posts: [{ id: '1', published: true, comments: [] }] }).post.create(createPayload) ).toResolveTruthy(); }); + + it('Default auth() on literal fields', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id + name String + score Int + + } + + model Post { + id String @id @default(uuid()) + title String + score Int? @default(auth().score) + authorName String? @default(auth().name) + + @@allow('all', true) + } + ` + ); + + const userDb = enhance({ id: '1', name: 'user1', score: 10 }); + await expect(userDb.post.create({ data: { title: 'abc' } })).toResolveTruthy(); + await expect(userDb.post.findMany()).resolves.toHaveLength(1); + await expect(userDb.post.count({ where: { authorName: 'user1', score: 10 } })).resolves.toBe(1); + }); + + it('Default auth() data should not override passed args', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id + name String + + } + + model Post { + id String @id @default(uuid()) + authorName String? @default(auth().name) + + @@allow('all', true) + } + ` + ); + + const userContextName = 'user1'; + const overrideName = 'no-default-auth-name'; + const userDb = enhance({ id: '1', name: userContextName }); + await expect(userDb.post.create({ data: { authorName: overrideName } })).toResolveTruthy(); + await expect(userDb.post.count({ where: { authorName: overrideName } })).resolves.toBe(1); + }); + + it('Default auth() with foreign key', async () => { + const { enhance, modelMeta } = await loadSchema( + ` + model User { + id String @id + posts Post[] + + @@allow('all', true) + + } + + model Post { + id String @id @default(uuid()) + title String + author User @relation(fields: [authorId], references: [id]) + authorId String @default(auth().id) + + @@allow('all', true) + } + ` + ); + + const db = enhance({ id: 'userId-1' }); + await expect(db.user.create({ data: { id: 'userId-1' } })).toResolveTruthy(); + await expect(db.post.create({ data: { title: 'abc' } })).resolves.toMatchObject({ authorId: 'userId-1' }); + }); + + it('Default auth() with nested user context value', async () => { + const { enhance } = await loadSchema( + ` + model User { + id String @id + profile Profile? + posts Post[] + + @@allow('all', true) + } + + model Profile { + id String @id @default(uuid()) + image Image? + user User @relation(fields: [userId], references: [id]) + userId String @unique + } + + model Image { + id String @id @default(uuid()) + url String + profile Profile @relation(fields: [profileId], references: [id]) + profileId String @unique + } + + model Post { + id String @id @default(uuid()) + title String + defaultImageUrl String @default(auth().profile.image.url) + author User @relation(fields: [authorId], references: [id]) + authorId String + + @@allow('all', true) + } + ` + ); + const url = 'https://zenstack.dev'; + const db = enhance({ id: 'userId-1', profile: { image: { url } } }); + + // top-level create + await expect(db.user.create({ data: { id: 'userId-1' } })).toResolveTruthy(); + await expect( + db.post.create({ data: { title: 'abc', author: { connect: { id: 'userId-1' } } } }) + ).resolves.toMatchObject({ defaultImageUrl: url }); + + // nested create + let result = await db.user.create({ + data: { + id: 'userId-2', + posts: { + create: [{ title: 'p1' }, { title: 'p2' }], + }, + }, + include: { posts: true }, + }); + expect(result.posts).toEqual( + expect.arrayContaining([ + expect.objectContaining({ title: 'p1', defaultImageUrl: url }), + expect.objectContaining({ title: 'p2', defaultImageUrl: url }), + ]) + ); + }); });