From 3cd7908bb34e0dc2748fc0b56be6d0cf300a5ea5 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 28 Nov 2023 23:33:14 -0800 Subject: [PATCH 1/3] fix: generate foreign key field in zod schemas --- packages/schema/src/plugins/zod/generator.ts | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/schema/src/plugins/zod/generator.ts b/packages/schema/src/plugins/zod/generator.ts index 1f3687166..ba0df3876 100644 --- a/packages/schema/src/plugins/zod/generator.ts +++ b/packages/schema/src/plugins/zod/generator.ts @@ -9,7 +9,6 @@ import { getPrismaClientImportSpec, hasAttribute, isEnumFieldReference, - isForeignKeyField, isFromStdlib, parseOptionAsStrings, resolvePath, @@ -263,7 +262,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s const fields = model.fields.filter( (field) => // scalar fields only - !isDataModel(field.type.reference?.ref) && !isForeignKeyField(field) + !isDataModel(field.type.reference?.ref) ); writer.writeLine('/* eslint-disable */'); From e572ce30626db76dc932a657e3062d75458364e1 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Wed, 29 Nov 2023 12:05:01 -0800 Subject: [PATCH 2/3] update --- packages/schema/src/plugins/zod/generator.ts | 181 ++++++++++++++++-- .../src/plugins/zod/utils/schema-gen.ts | 20 +- 2 files changed, 182 insertions(+), 19 deletions(-) diff --git a/packages/schema/src/plugins/zod/generator.ts b/packages/schema/src/plugins/zod/generator.ts index ba0df3876..455053c72 100644 --- a/packages/schema/src/plugins/zod/generator.ts +++ b/packages/schema/src/plugins/zod/generator.ts @@ -4,17 +4,29 @@ import { PluginOptions, createProject, emitProject, + getAttribute, + getAttributeArg, getDataModels, getLiteral, getPrismaClientImportSpec, hasAttribute, isEnumFieldReference, + isForeignKeyField, isFromStdlib, parseOptionAsStrings, resolvePath, saveProject, } from '@zenstackhq/sdk'; -import { DataModel, DataSource, EnumField, Model, isDataModel, isDataSource, isEnum } from '@zenstackhq/sdk/ast'; +import { + DataModel, + DataModelField, + DataSource, + EnumField, + Model, + isDataModel, + isDataSource, + isEnum, +} from '@zenstackhq/sdk/ast'; import { addMissingInputObjectTypes, resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers'; import { promises as fs } from 'fs'; import { streamAllContents } from 'langium'; @@ -261,8 +273,15 @@ async function generateModelSchema(model: DataModel, project: Project, output: s sf.replaceWithText((writer) => { const fields = model.fields.filter( (field) => - // scalar fields only - !isDataModel(field.type.reference?.ref) + // regular fields only + !isDataModel(field.type.reference?.ref) && !isForeignKeyField(field) + ); + + const relations = model.fields.filter((field) => isDataModel(field.type.reference?.ref)); + const fkFields = model.fields.filter((field) => isForeignKeyField(field)); + // unsafe version of relations: including foreign keys and relation fields without fk + const unsafeRelations = model.fields.filter( + (field) => isForeignKeyField(field) || (isDataModel(field.type.reference?.ref) && !hasForeignKey(field)) ); writer.writeLine('/* eslint-disable */'); @@ -301,7 +320,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s writer.writeLine(`import { Decimal } from 'decimal.js';`); } - // create base schema + // base schema writer.write(`const baseSchema = z.object(`); writer.inlineBlock(() => { fields.forEach((field) => { @@ -310,18 +329,63 @@ async function generateModelSchema(model: DataModel, project: Project, output: s }); writer.writeLine(');'); + // relation fields + + let allRelationSchema: string | undefined; + let safeRelationSchema: string | undefined; + let unsafeRelationSchema: string | undefined; + + if (relations.length > 0 || fkFields.length > 0) { + allRelationSchema = 'allRelationSchema'; + writer.write(`const ${allRelationSchema} = z.object(`); + writer.inlineBlock(() => { + [...relations, ...fkFields].forEach((field) => { + writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`); + }); + }); + writer.writeLine(');'); + } + + if (relations.length > 0) { + safeRelationSchema = 'safeRelationSchema'; + writer.write(`const ${safeRelationSchema} = z.object(`); + writer.inlineBlock(() => { + relations.forEach((field) => { + writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`); + }); + }); + writer.writeLine(');'); + } + + if (unsafeRelations.length > 0) { + unsafeRelationSchema = 'unsafeRelationSchema'; + writer.write(`const ${unsafeRelationSchema} = z.object(`); + writer.inlineBlock(() => { + unsafeRelations.forEach((field) => { + writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`); + }); + }); + writer.writeLine(');'); + } + // compile "@@validate" to ".refine" const refinements = makeValidationRefinements(model); + let refineFuncName: string | undefined; if (refinements.length > 0) { + refineFuncName = `refine${upperCaseFirst(model.name)}`; writer.writeLine( - `function refine(schema: z.ZodType) { return schema${refinements.join( + `export function ${refineFuncName}(schema: z.ZodType) { return schema${refinements.join( '\n' )}; }` ); } - // model schema + //////////////////////////////////////////////// + // 1. Model schema + //////////////////////////////////////////////// let modelSchema = 'baseSchema'; + + // omit fields const fieldsToOmit = fields.filter((field) => hasAttribute(field, '@omit')); if (fieldsToOmit.length > 0) { modelSchema = makeOmit( @@ -329,12 +393,28 @@ async function generateModelSchema(model: DataModel, project: Project, output: s fieldsToOmit.map((f) => f.name) ); } - if (refinements.length > 0) { - modelSchema = `refine(${modelSchema})`; + + if (allRelationSchema) { + // export schema with only scalar fields + const modelScalarSchema = `${upperCaseFirst(model.name)}ScalarSchema`; + writer.writeLine(`export const ${modelScalarSchema} = ${modelSchema};`); + modelSchema = modelScalarSchema; + + // merge relations + modelSchema = makeMerge(modelSchema, allRelationSchema); + } + + // refine + if (refineFuncName) { + const noRefineSchema = `${upperCaseFirst(model.name)}WithoutRefineSchema`; + writer.writeLine(`export const ${noRefineSchema} = ${modelSchema};`); + modelSchema = `${refineFuncName}(${noRefineSchema})`; } writer.writeLine(`export const ${upperCaseFirst(model.name)}Schema = ${modelSchema};`); - // create schema + //////////////////////////////////////////////// + // 2. Create schema + //////////////////////////////////////////////// let createSchema = 'baseSchema'; const fieldsWithDefault = fields.filter( (field) => hasAttribute(field, '@default') || hasAttribute(field, '@updatedAt') || field.type.array @@ -345,25 +425,76 @@ async function generateModelSchema(model: DataModel, project: Project, output: s fieldsWithDefault.map((f) => f.name) ); } - if (refinements.length > 0) { - createSchema = `refine(${createSchema})`; + + if (safeRelationSchema || unsafeRelationSchema) { + // export schema with only scalar fields + const createScalarSchema = `${upperCaseFirst(model.name)}CreateScalarSchema`; + writer.writeLine(`export const ${createScalarSchema} = ${createSchema};`); + createSchema = createScalarSchema; + + if (safeRelationSchema && unsafeRelationSchema) { + // build a union of with relation object fields and with fk fields (mutually exclusive) + createSchema = makeUnion( + makeMerge(createSchema, safeRelationSchema), + makeMerge(createSchema, unsafeRelationSchema) + ); + } else if (safeRelationSchema) { + // just relation + createSchema = makeMerge(createSchema, safeRelationSchema); + } + } + + if (refineFuncName) { + // export a schema without refinement for extensibility + const noRefineSchema = `${upperCaseFirst(model.name)}CreateWithoutRefineSchema`; + writer.writeLine(`export const ${noRefineSchema} = ${createSchema};`); + createSchema = `${refineFuncName}(${noRefineSchema})`; } writer.writeLine(`export const ${upperCaseFirst(model.name)}CreateSchema = ${createSchema};`); - // update schema - let updateSchema = 'baseSchema.partial()'; - if (refinements.length > 0) { - updateSchema = `refine(${updateSchema})`; + //////////////////////////////////////////////// + // 3. Update schema + //////////////////////////////////////////////// + let updateSchema = makePartial('baseSchema'); + + if (safeRelationSchema || unsafeRelationSchema) { + // export schema with only scalar fields + const updateScalarSchema = `${upperCaseFirst(model.name)}UpdateScalarSchema`; + writer.writeLine(`export const ${updateScalarSchema} = ${updateSchema};`); + updateSchema = updateScalarSchema; + + if (safeRelationSchema && unsafeRelationSchema) { + // build a union of with relation object fields and with fk fields (mutually exclusive) + updateSchema = makeUnion( + makeMerge(updateSchema, makePartial(safeRelationSchema)), + makeMerge(updateSchema, makePartial(unsafeRelationSchema)) + ); + } else if (safeRelationSchema) { + // just relation + updateSchema = makeMerge(updateSchema, makePartial(safeRelationSchema)); + } + } + + if (refineFuncName) { + // export a schema without refinement for extensibility + const noRefineSchema = `${upperCaseFirst(model.name)}UpdateWithoutRefineSchema`; + writer.writeLine(`export const ${noRefineSchema} = ${updateSchema};`); + updateSchema = `${refineFuncName}(${noRefineSchema})`; } writer.writeLine(`export const ${upperCaseFirst(model.name)}UpdateSchema = ${updateSchema};`); }); + return schemaName; } -function makePartial(schema: string, fields: string[]) { - return `${schema}.partial({ +function makePartial(schema: string, fields?: string[]) { + if (fields) { + return `${schema}.partial({ ${fields.map((f) => `${f}: true`).join(', ')}, })`; + } else { + return `${schema}.partial()`; + } } function makeOmit(schema: string, fields: string[]) { @@ -371,3 +502,19 @@ function makeOmit(schema: string, fields: string[]) { ${fields.map((f) => `${f}: true`).join(', ')}, })`; } + +function makeMerge(schema1: string, schema2: string): string { + return `${schema1}.merge(${schema2})`; +} + +function makeUnion(...schemas: string[]): string { + return `z.union([${schemas.join(', ')}])`; +} + +function hasForeignKey(field: DataModelField) { + const relAttr = getAttribute(field, '@relation'); + if (!relAttr) { + return false; + } + return !!getAttributeArg(relAttr, 'fields'); +} diff --git a/packages/schema/src/plugins/zod/utils/schema-gen.ts b/packages/schema/src/plugins/zod/utils/schema-gen.ts index 0676c40d6..a73b34924 100644 --- a/packages/schema/src/plugins/zod/utils/schema-gen.ts +++ b/packages/schema/src/plugins/zod/utils/schema-gen.ts @@ -1,5 +1,5 @@ import { ExpressionContext, PluginError, getAttributeArg, getAttributeArgLiteral, getLiteral } from '@zenstackhq/sdk'; -import { DataModel, DataModelField, DataModelFieldAttribute, isEnum } from '@zenstackhq/sdk/ast'; +import { DataModel, DataModelField, DataModelFieldAttribute, isDataModel, isEnum } from '@zenstackhq/sdk/ast'; import { upperCaseFirst } from 'upper-case-first'; import { name } from '..'; import { @@ -7,7 +7,23 @@ import { TypeScriptExpressionTransformerError, } from '../../../utils/typescript-expression-transformer'; -export function makeFieldSchema(field: DataModelField) { +export function makeFieldSchema(field: DataModelField, forMutation = false) { + if (isDataModel(field.type.reference?.ref)) { + if (!forMutation) { + // read schema, always optional + if (field.type.array) { + return `z.array(z.unknown()).optional()`; + } else { + return `z.record(z.unknown()).optional()`; + } + } else { + // write schema + return `${ + field.type.optional || field.type.array ? 'z.record(z.unknown()).optional()' : 'z.record(z.unknown())' + }`; + } + } + let schema = makeZodSchema(field); const isDecimal = field.type.type === 'Decimal'; From b2a37ba6b820514e0fdaebf78d9a39f77009a47d Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Wed, 29 Nov 2023 12:26:42 -0800 Subject: [PATCH 3/3] fix --- packages/schema/src/plugins/zod/generator.ts | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/packages/schema/src/plugins/zod/generator.ts b/packages/schema/src/plugins/zod/generator.ts index 455053c72..04c903577 100644 --- a/packages/schema/src/plugins/zod/generator.ts +++ b/packages/schema/src/plugins/zod/generator.ts @@ -434,13 +434,21 @@ async function generateModelSchema(model: DataModel, project: Project, output: s if (safeRelationSchema && unsafeRelationSchema) { // build a union of with relation object fields and with fk fields (mutually exclusive) + + // TODO: we make all relation fields partial for now because in case of + // nested create, not all relation/fk fields are inside payload, need a + // better solution createSchema = makeUnion( - makeMerge(createSchema, safeRelationSchema), - makeMerge(createSchema, unsafeRelationSchema) + makeMerge(createSchema, makePartial(safeRelationSchema)), + makeMerge(createSchema, makePartial(unsafeRelationSchema)) ); } else if (safeRelationSchema) { // just relation - createSchema = makeMerge(createSchema, safeRelationSchema); + + // TODO: we make all relation fields partial for now because in case of + // nested create, not all relation/fk fields are inside payload, need a + // better solution + createSchema = makeMerge(createSchema, makePartial(safeRelationSchema)); } }