diff --git a/packages/schema/src/plugins/zod/generator.ts b/packages/schema/src/plugins/zod/generator.ts index aeb7bdd83..08734dc0c 100644 --- a/packages/schema/src/plugins/zod/generator.ts +++ b/packages/schema/src/plugins/zod/generator.ts @@ -7,24 +7,28 @@ import { emitProject, getDataModels, getLiteral, + getPrismaClientImportSpec, hasAttribute, + isEnumFieldReference, isForeignKeyField, resolvePath, saveProject, } from '@zenstackhq/sdk'; -import { DataModel, DataSource, Model, isDataModel, isDataSource, isEnum } from '@zenstackhq/sdk/ast'; +import { DataModel, DataSource, EnumField, Model, isDataModel, isDataSource, isEnum } from '@zenstackhq/sdk/ast'; import { AggregateOperationSupport, addMissingInputObjectTypes, resolveAggregateOperationSupport, } from '@zenstackhq/sdk/dmmf-helpers'; import { promises as fs } from 'fs'; +import { streamAllContents } from 'langium'; import path from 'path'; import { Project } from 'ts-morph'; +import { upperCaseFirst } from 'upper-case-first'; +import { isFromStdlib } from '../../language-server/utils'; import { getDefaultOutputFolder } from '../plugin-utils'; import Transformer from './transformer'; import removeDir from './utils/removeDir'; -import { upperCaseFirst } from 'upper-case-first'; import { makeFieldSchema, makeValidationRefinements } from './utils/schema-gen'; export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.Document) { @@ -176,8 +180,6 @@ async function generateModelSchema(model: DataModel, project: Project, output: s overwrite: true, }); sf.replaceWithText((writer) => { - writer.writeLine('/* eslint-disable */'); - const fields = model.fields.filter( (field) => !AUXILIARY_FIELDS.includes(field.name) && @@ -186,9 +188,25 @@ async function generateModelSchema(model: DataModel, project: Project, output: s !isForeignKeyField(field) ); + writer.writeLine('/* eslint-disable */'); writer.writeLine(`import { z } from 'zod';`); - // import enums + // import user-defined enums from Prisma as they might be referenced in the expressions + const importEnums = new Set(); + for (const node of streamAllContents(model)) { + if (isEnumFieldReference(node)) { + const field = node.target.ref as EnumField; + if (!isFromStdlib(field.$container)) { + importEnums.add(field.$container.name); + } + } + } + if (importEnums.size > 0) { + const prismaImport = getPrismaClientImportSpec(model.$container, path.join(output, 'models')); + writer.writeLine(`import { ${[...importEnums].join(', ')} } from '${prismaImport}';`); + } + + // import enum schemas for (const field of fields) { if (field.type.reference?.ref && isEnum(field.type.reference?.ref)) { const name = upperCaseFirst(field.type.reference?.ref.name); @@ -205,9 +223,9 @@ async function generateModelSchema(model: DataModel, project: Project, output: s }); writer.writeLine(');'); + // compile "@@validate" to ".refine" const refinements = makeValidationRefinements(model); if (refinements.length > 0) { - console.log('Generated refinements:', refinements); writer.writeLine(`function refine(schema: z.ZodType) { return schema${refinements.join('\n')}; }`); } diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index 3a1cdb72d..3b26c01f1 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -133,8 +133,8 @@ export function getAttributeArgLiteral( return undefined; } -export function isEnumFieldReference(expr: Expression): expr is ReferenceExpr { - return isReferenceExpr(expr) && isEnumField(expr.target.ref); +export function isEnumFieldReference(node: AstNode): node is ReferenceExpr { + return isReferenceExpr(node) && isEnumField(node.target.ref); } /**