diff --git a/packages/schema/src/plugins/zod/generator.ts b/packages/schema/src/plugins/zod/generator.ts index 1f3687166..04c903577 100644 --- a/packages/schema/src/plugins/zod/generator.ts +++ b/packages/schema/src/plugins/zod/generator.ts @@ -4,6 +4,8 @@ import { PluginOptions, createProject, emitProject, + getAttribute, + getAttributeArg, getDataModels, getLiteral, getPrismaClientImportSpec, @@ -15,7 +17,16 @@ import { resolvePath, saveProject, } from '@zenstackhq/sdk'; -import { DataModel, DataSource, EnumField, Model, isDataModel, isDataSource, isEnum } from '@zenstackhq/sdk/ast'; +import { + DataModel, + DataModelField, + DataSource, + EnumField, + Model, + isDataModel, + isDataSource, + isEnum, +} from '@zenstackhq/sdk/ast'; import { addMissingInputObjectTypes, resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers'; import { promises as fs } from 'fs'; import { streamAllContents } from 'langium'; @@ -262,10 +273,17 @@ async function generateModelSchema(model: DataModel, project: Project, output: s sf.replaceWithText((writer) => { const fields = model.fields.filter( (field) => - // scalar fields only + // regular fields only !isDataModel(field.type.reference?.ref) && !isForeignKeyField(field) ); + const relations = model.fields.filter((field) => isDataModel(field.type.reference?.ref)); + const fkFields = model.fields.filter((field) => isForeignKeyField(field)); + // unsafe version of relations: including foreign keys and relation fields without fk + const unsafeRelations = model.fields.filter( + (field) => isForeignKeyField(field) || (isDataModel(field.type.reference?.ref) && !hasForeignKey(field)) + ); + writer.writeLine('/* eslint-disable */'); writer.writeLine(`import { z } from 'zod';`); @@ -302,7 +320,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s writer.writeLine(`import { Decimal } from 'decimal.js';`); } - // create base schema + // base schema writer.write(`const baseSchema = z.object(`); writer.inlineBlock(() => { fields.forEach((field) => { @@ -311,18 +329,63 @@ async function generateModelSchema(model: DataModel, project: Project, output: s }); writer.writeLine(');'); + // relation fields + + let allRelationSchema: string | undefined; + let safeRelationSchema: string | undefined; + let unsafeRelationSchema: string | undefined; + + if (relations.length > 0 || fkFields.length > 0) { + allRelationSchema = 'allRelationSchema'; + writer.write(`const ${allRelationSchema} = z.object(`); + writer.inlineBlock(() => { + [...relations, ...fkFields].forEach((field) => { + writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`); + }); + }); + writer.writeLine(');'); + } + + if (relations.length > 0) { + safeRelationSchema = 'safeRelationSchema'; + writer.write(`const ${safeRelationSchema} = z.object(`); + writer.inlineBlock(() => { + relations.forEach((field) => { + writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`); + }); + }); + writer.writeLine(');'); + } + + if (unsafeRelations.length > 0) { + unsafeRelationSchema = 'unsafeRelationSchema'; + writer.write(`const ${unsafeRelationSchema} = z.object(`); + writer.inlineBlock(() => { + unsafeRelations.forEach((field) => { + writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`); + }); + }); + writer.writeLine(');'); + } + // compile "@@validate" to ".refine" const refinements = makeValidationRefinements(model); + let refineFuncName: string | undefined; if (refinements.length > 0) { + refineFuncName = `refine${upperCaseFirst(model.name)}`; writer.writeLine( - `function refine(schema: z.ZodType) { return schema${refinements.join( + `export function ${refineFuncName}(schema: z.ZodType) { return schema${refinements.join( '\n' )}; }` ); } - // model schema + //////////////////////////////////////////////// + // 1. Model schema + //////////////////////////////////////////////// let modelSchema = 'baseSchema'; + + // omit fields const fieldsToOmit = fields.filter((field) => hasAttribute(field, '@omit')); if (fieldsToOmit.length > 0) { modelSchema = makeOmit( @@ -330,12 +393,28 @@ async function generateModelSchema(model: DataModel, project: Project, output: s fieldsToOmit.map((f) => f.name) ); } - if (refinements.length > 0) { - modelSchema = `refine(${modelSchema})`; + + if (allRelationSchema) { + // export schema with only scalar fields + const modelScalarSchema = `${upperCaseFirst(model.name)}ScalarSchema`; + writer.writeLine(`export const ${modelScalarSchema} = ${modelSchema};`); + modelSchema = modelScalarSchema; + + // merge relations + modelSchema = makeMerge(modelSchema, allRelationSchema); + } + + // refine + if (refineFuncName) { + const noRefineSchema = `${upperCaseFirst(model.name)}WithoutRefineSchema`; + writer.writeLine(`export const ${noRefineSchema} = ${modelSchema};`); + modelSchema = `${refineFuncName}(${noRefineSchema})`; } writer.writeLine(`export const ${upperCaseFirst(model.name)}Schema = ${modelSchema};`); - // create schema + //////////////////////////////////////////////// + // 2. Create schema + //////////////////////////////////////////////// let createSchema = 'baseSchema'; const fieldsWithDefault = fields.filter( (field) => hasAttribute(field, '@default') || hasAttribute(field, '@updatedAt') || field.type.array @@ -346,25 +425,84 @@ async function generateModelSchema(model: DataModel, project: Project, output: s fieldsWithDefault.map((f) => f.name) ); } - if (refinements.length > 0) { - createSchema = `refine(${createSchema})`; + + if (safeRelationSchema || unsafeRelationSchema) { + // export schema with only scalar fields + const createScalarSchema = `${upperCaseFirst(model.name)}CreateScalarSchema`; + writer.writeLine(`export const ${createScalarSchema} = ${createSchema};`); + createSchema = createScalarSchema; + + if (safeRelationSchema && unsafeRelationSchema) { + // build a union of with relation object fields and with fk fields (mutually exclusive) + + // TODO: we make all relation fields partial for now because in case of + // nested create, not all relation/fk fields are inside payload, need a + // better solution + createSchema = makeUnion( + makeMerge(createSchema, makePartial(safeRelationSchema)), + makeMerge(createSchema, makePartial(unsafeRelationSchema)) + ); + } else if (safeRelationSchema) { + // just relation + + // TODO: we make all relation fields partial for now because in case of + // nested create, not all relation/fk fields are inside payload, need a + // better solution + createSchema = makeMerge(createSchema, makePartial(safeRelationSchema)); + } + } + + if (refineFuncName) { + // export a schema without refinement for extensibility + const noRefineSchema = `${upperCaseFirst(model.name)}CreateWithoutRefineSchema`; + writer.writeLine(`export const ${noRefineSchema} = ${createSchema};`); + createSchema = `${refineFuncName}(${noRefineSchema})`; } writer.writeLine(`export const ${upperCaseFirst(model.name)}CreateSchema = ${createSchema};`); - // update schema - let updateSchema = 'baseSchema.partial()'; - if (refinements.length > 0) { - updateSchema = `refine(${updateSchema})`; + //////////////////////////////////////////////// + // 3. Update schema + //////////////////////////////////////////////// + let updateSchema = makePartial('baseSchema'); + + if (safeRelationSchema || unsafeRelationSchema) { + // export schema with only scalar fields + const updateScalarSchema = `${upperCaseFirst(model.name)}UpdateScalarSchema`; + writer.writeLine(`export const ${updateScalarSchema} = ${updateSchema};`); + updateSchema = updateScalarSchema; + + if (safeRelationSchema && unsafeRelationSchema) { + // build a union of with relation object fields and with fk fields (mutually exclusive) + updateSchema = makeUnion( + makeMerge(updateSchema, makePartial(safeRelationSchema)), + makeMerge(updateSchema, makePartial(unsafeRelationSchema)) + ); + } else if (safeRelationSchema) { + // just relation + updateSchema = makeMerge(updateSchema, makePartial(safeRelationSchema)); + } + } + + if (refineFuncName) { + // export a schema without refinement for extensibility + const noRefineSchema = `${upperCaseFirst(model.name)}UpdateWithoutRefineSchema`; + writer.writeLine(`export const ${noRefineSchema} = ${updateSchema};`); + updateSchema = `${refineFuncName}(${noRefineSchema})`; } writer.writeLine(`export const ${upperCaseFirst(model.name)}UpdateSchema = ${updateSchema};`); }); + return schemaName; } -function makePartial(schema: string, fields: string[]) { - return `${schema}.partial({ +function makePartial(schema: string, fields?: string[]) { + if (fields) { + return `${schema}.partial({ ${fields.map((f) => `${f}: true`).join(', ')}, })`; + } else { + return `${schema}.partial()`; + } } function makeOmit(schema: string, fields: string[]) { @@ -372,3 +510,19 @@ function makeOmit(schema: string, fields: string[]) { ${fields.map((f) => `${f}: true`).join(', ')}, })`; } + +function makeMerge(schema1: string, schema2: string): string { + return `${schema1}.merge(${schema2})`; +} + +function makeUnion(...schemas: string[]): string { + return `z.union([${schemas.join(', ')}])`; +} + +function hasForeignKey(field: DataModelField) { + const relAttr = getAttribute(field, '@relation'); + if (!relAttr) { + return false; + } + return !!getAttributeArg(relAttr, 'fields'); +} diff --git a/packages/schema/src/plugins/zod/utils/schema-gen.ts b/packages/schema/src/plugins/zod/utils/schema-gen.ts index 0676c40d6..a73b34924 100644 --- a/packages/schema/src/plugins/zod/utils/schema-gen.ts +++ b/packages/schema/src/plugins/zod/utils/schema-gen.ts @@ -1,5 +1,5 @@ import { ExpressionContext, PluginError, getAttributeArg, getAttributeArgLiteral, getLiteral } from '@zenstackhq/sdk'; -import { DataModel, DataModelField, DataModelFieldAttribute, isEnum } from '@zenstackhq/sdk/ast'; +import { DataModel, DataModelField, DataModelFieldAttribute, isDataModel, isEnum } from '@zenstackhq/sdk/ast'; import { upperCaseFirst } from 'upper-case-first'; import { name } from '..'; import { @@ -7,7 +7,23 @@ import { TypeScriptExpressionTransformerError, } from '../../../utils/typescript-expression-transformer'; -export function makeFieldSchema(field: DataModelField) { +export function makeFieldSchema(field: DataModelField, forMutation = false) { + if (isDataModel(field.type.reference?.ref)) { + if (!forMutation) { + // read schema, always optional + if (field.type.array) { + return `z.array(z.unknown()).optional()`; + } else { + return `z.record(z.unknown()).optional()`; + } + } else { + // write schema + return `${ + field.type.optional || field.type.array ? 'z.record(z.unknown()).optional()' : 'z.record(z.unknown())' + }`; + } + } + let schema = makeZodSchema(field); const isDecimal = field.type.type === 'Decimal';