Skip to content

Commit 124a0a2

Browse files
authored
fix: generate foreign key field in zod schemas (#868)
1 parent bf85ceb commit 124a0a2

File tree

2 files changed

+188
-18
lines changed

2 files changed

+188
-18
lines changed

packages/schema/src/plugins/zod/generator.ts

Lines changed: 170 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import {
44
PluginOptions,
55
createProject,
66
emitProject,
7+
getAttribute,
8+
getAttributeArg,
79
getDataModels,
810
getLiteral,
911
getPrismaClientImportSpec,
@@ -15,7 +17,16 @@ import {
1517
resolvePath,
1618
saveProject,
1719
} from '@zenstackhq/sdk';
18-
import { DataModel, DataSource, EnumField, Model, isDataModel, isDataSource, isEnum } from '@zenstackhq/sdk/ast';
20+
import {
21+
DataModel,
22+
DataModelField,
23+
DataSource,
24+
EnumField,
25+
Model,
26+
isDataModel,
27+
isDataSource,
28+
isEnum,
29+
} from '@zenstackhq/sdk/ast';
1930
import { addMissingInputObjectTypes, resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers';
2031
import { promises as fs } from 'fs';
2132
import { streamAllContents } from 'langium';
@@ -262,10 +273,17 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
262273
sf.replaceWithText((writer) => {
263274
const fields = model.fields.filter(
264275
(field) =>
265-
// scalar fields only
276+
// regular fields only
266277
!isDataModel(field.type.reference?.ref) && !isForeignKeyField(field)
267278
);
268279

280+
const relations = model.fields.filter((field) => isDataModel(field.type.reference?.ref));
281+
const fkFields = model.fields.filter((field) => isForeignKeyField(field));
282+
// unsafe version of relations: including foreign keys and relation fields without fk
283+
const unsafeRelations = model.fields.filter(
284+
(field) => isForeignKeyField(field) || (isDataModel(field.type.reference?.ref) && !hasForeignKey(field))
285+
);
286+
269287
writer.writeLine('/* eslint-disable */');
270288
writer.writeLine(`import { z } from 'zod';`);
271289

@@ -302,7 +320,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
302320
writer.writeLine(`import { Decimal } from 'decimal.js';`);
303321
}
304322

305-
// create base schema
323+
// base schema
306324
writer.write(`const baseSchema = z.object(`);
307325
writer.inlineBlock(() => {
308326
fields.forEach((field) => {
@@ -311,31 +329,92 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
311329
});
312330
writer.writeLine(');');
313331

332+
// relation fields
333+
334+
let allRelationSchema: string | undefined;
335+
let safeRelationSchema: string | undefined;
336+
let unsafeRelationSchema: string | undefined;
337+
338+
if (relations.length > 0 || fkFields.length > 0) {
339+
allRelationSchema = 'allRelationSchema';
340+
writer.write(`const ${allRelationSchema} = z.object(`);
341+
writer.inlineBlock(() => {
342+
[...relations, ...fkFields].forEach((field) => {
343+
writer.writeLine(`${field.name}: ${makeFieldSchema(field)},`);
344+
});
345+
});
346+
writer.writeLine(');');
347+
}
348+
349+
if (relations.length > 0) {
350+
safeRelationSchema = 'safeRelationSchema';
351+
writer.write(`const ${safeRelationSchema} = z.object(`);
352+
writer.inlineBlock(() => {
353+
relations.forEach((field) => {
354+
writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`);
355+
});
356+
});
357+
writer.writeLine(');');
358+
}
359+
360+
if (unsafeRelations.length > 0) {
361+
unsafeRelationSchema = 'unsafeRelationSchema';
362+
writer.write(`const ${unsafeRelationSchema} = z.object(`);
363+
writer.inlineBlock(() => {
364+
unsafeRelations.forEach((field) => {
365+
writer.writeLine(`${field.name}: ${makeFieldSchema(field, true)},`);
366+
});
367+
});
368+
writer.writeLine(');');
369+
}
370+
314371
// compile "@@validate" to ".refine"
315372
const refinements = makeValidationRefinements(model);
373+
let refineFuncName: string | undefined;
316374
if (refinements.length > 0) {
375+
refineFuncName = `refine${upperCaseFirst(model.name)}`;
317376
writer.writeLine(
318-
`function refine<T, D extends z.ZodTypeDef>(schema: z.ZodType<T, D, T>) { return schema${refinements.join(
377+
`export function ${refineFuncName}<T, D extends z.ZodTypeDef>(schema: z.ZodType<T, D, T>) { return schema${refinements.join(
319378
'\n'
320379
)}; }`
321380
);
322381
}
323382

324-
// model schema
383+
////////////////////////////////////////////////
384+
// 1. Model schema
385+
////////////////////////////////////////////////
325386
let modelSchema = 'baseSchema';
387+
388+
// omit fields
326389
const fieldsToOmit = fields.filter((field) => hasAttribute(field, '@omit'));
327390
if (fieldsToOmit.length > 0) {
328391
modelSchema = makeOmit(
329392
modelSchema,
330393
fieldsToOmit.map((f) => f.name)
331394
);
332395
}
333-
if (refinements.length > 0) {
334-
modelSchema = `refine(${modelSchema})`;
396+
397+
if (allRelationSchema) {
398+
// export schema with only scalar fields
399+
const modelScalarSchema = `${upperCaseFirst(model.name)}ScalarSchema`;
400+
writer.writeLine(`export const ${modelScalarSchema} = ${modelSchema};`);
401+
modelSchema = modelScalarSchema;
402+
403+
// merge relations
404+
modelSchema = makeMerge(modelSchema, allRelationSchema);
405+
}
406+
407+
// refine
408+
if (refineFuncName) {
409+
const noRefineSchema = `${upperCaseFirst(model.name)}WithoutRefineSchema`;
410+
writer.writeLine(`export const ${noRefineSchema} = ${modelSchema};`);
411+
modelSchema = `${refineFuncName}(${noRefineSchema})`;
335412
}
336413
writer.writeLine(`export const ${upperCaseFirst(model.name)}Schema = ${modelSchema};`);
337414

338-
// create schema
415+
////////////////////////////////////////////////
416+
// 2. Create schema
417+
////////////////////////////////////////////////
339418
let createSchema = 'baseSchema';
340419
const fieldsWithDefault = fields.filter(
341420
(field) => hasAttribute(field, '@default') || hasAttribute(field, '@updatedAt') || field.type.array
@@ -346,29 +425,104 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
346425
fieldsWithDefault.map((f) => f.name)
347426
);
348427
}
349-
if (refinements.length > 0) {
350-
createSchema = `refine(${createSchema})`;
428+
429+
if (safeRelationSchema || unsafeRelationSchema) {
430+
// export schema with only scalar fields
431+
const createScalarSchema = `${upperCaseFirst(model.name)}CreateScalarSchema`;
432+
writer.writeLine(`export const ${createScalarSchema} = ${createSchema};`);
433+
createSchema = createScalarSchema;
434+
435+
if (safeRelationSchema && unsafeRelationSchema) {
436+
// build a union of with relation object fields and with fk fields (mutually exclusive)
437+
438+
// TODO: we make all relation fields partial for now because in case of
439+
// nested create, not all relation/fk fields are inside payload, need a
440+
// better solution
441+
createSchema = makeUnion(
442+
makeMerge(createSchema, makePartial(safeRelationSchema)),
443+
makeMerge(createSchema, makePartial(unsafeRelationSchema))
444+
);
445+
} else if (safeRelationSchema) {
446+
// just relation
447+
448+
// TODO: we make all relation fields partial for now because in case of
449+
// nested create, not all relation/fk fields are inside payload, need a
450+
// better solution
451+
createSchema = makeMerge(createSchema, makePartial(safeRelationSchema));
452+
}
453+
}
454+
455+
if (refineFuncName) {
456+
// export a schema without refinement for extensibility
457+
const noRefineSchema = `${upperCaseFirst(model.name)}CreateWithoutRefineSchema`;
458+
writer.writeLine(`export const ${noRefineSchema} = ${createSchema};`);
459+
createSchema = `${refineFuncName}(${noRefineSchema})`;
351460
}
352461
writer.writeLine(`export const ${upperCaseFirst(model.name)}CreateSchema = ${createSchema};`);
353462

354-
// update schema
355-
let updateSchema = 'baseSchema.partial()';
356-
if (refinements.length > 0) {
357-
updateSchema = `refine(${updateSchema})`;
463+
////////////////////////////////////////////////
464+
// 3. Update schema
465+
////////////////////////////////////////////////
466+
let updateSchema = makePartial('baseSchema');
467+
468+
if (safeRelationSchema || unsafeRelationSchema) {
469+
// export schema with only scalar fields
470+
const updateScalarSchema = `${upperCaseFirst(model.name)}UpdateScalarSchema`;
471+
writer.writeLine(`export const ${updateScalarSchema} = ${updateSchema};`);
472+
updateSchema = updateScalarSchema;
473+
474+
if (safeRelationSchema && unsafeRelationSchema) {
475+
// build a union of with relation object fields and with fk fields (mutually exclusive)
476+
updateSchema = makeUnion(
477+
makeMerge(updateSchema, makePartial(safeRelationSchema)),
478+
makeMerge(updateSchema, makePartial(unsafeRelationSchema))
479+
);
480+
} else if (safeRelationSchema) {
481+
// just relation
482+
updateSchema = makeMerge(updateSchema, makePartial(safeRelationSchema));
483+
}
484+
}
485+
486+
if (refineFuncName) {
487+
// export a schema without refinement for extensibility
488+
const noRefineSchema = `${upperCaseFirst(model.name)}UpdateWithoutRefineSchema`;
489+
writer.writeLine(`export const ${noRefineSchema} = ${updateSchema};`);
490+
updateSchema = `${refineFuncName}(${noRefineSchema})`;
358491
}
359492
writer.writeLine(`export const ${upperCaseFirst(model.name)}UpdateSchema = ${updateSchema};`);
360493
});
494+
361495
return schemaName;
362496
}
363497

364-
function makePartial(schema: string, fields: string[]) {
365-
return `${schema}.partial({
498+
function makePartial(schema: string, fields?: string[]) {
499+
if (fields) {
500+
return `${schema}.partial({
366501
${fields.map((f) => `${f}: true`).join(', ')},
367502
})`;
503+
} else {
504+
return `${schema}.partial()`;
505+
}
368506
}
369507

370508
function makeOmit(schema: string, fields: string[]) {
371509
return `${schema}.omit({
372510
${fields.map((f) => `${f}: true`).join(', ')},
373511
})`;
374512
}
513+
514+
function makeMerge(schema1: string, schema2: string): string {
515+
return `${schema1}.merge(${schema2})`;
516+
}
517+
518+
function makeUnion(...schemas: string[]): string {
519+
return `z.union([${schemas.join(', ')}])`;
520+
}
521+
522+
function hasForeignKey(field: DataModelField) {
523+
const relAttr = getAttribute(field, '@relation');
524+
if (!relAttr) {
525+
return false;
526+
}
527+
return !!getAttributeArg(relAttr, 'fields');
528+
}

packages/schema/src/plugins/zod/utils/schema-gen.ts

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,29 @@
11
import { ExpressionContext, PluginError, getAttributeArg, getAttributeArgLiteral, getLiteral } from '@zenstackhq/sdk';
2-
import { DataModel, DataModelField, DataModelFieldAttribute, isEnum } from '@zenstackhq/sdk/ast';
2+
import { DataModel, DataModelField, DataModelFieldAttribute, isDataModel, isEnum } from '@zenstackhq/sdk/ast';
33
import { upperCaseFirst } from 'upper-case-first';
44
import { name } from '..';
55
import {
66
TypeScriptExpressionTransformer,
77
TypeScriptExpressionTransformerError,
88
} from '../../../utils/typescript-expression-transformer';
99

10-
export function makeFieldSchema(field: DataModelField) {
10+
export function makeFieldSchema(field: DataModelField, forMutation = false) {
11+
if (isDataModel(field.type.reference?.ref)) {
12+
if (!forMutation) {
13+
// read schema, always optional
14+
if (field.type.array) {
15+
return `z.array(z.unknown()).optional()`;
16+
} else {
17+
return `z.record(z.unknown()).optional()`;
18+
}
19+
} else {
20+
// write schema
21+
return `${
22+
field.type.optional || field.type.array ? 'z.record(z.unknown()).optional()' : 'z.record(z.unknown())'
23+
}`;
24+
}
25+
}
26+
1127
let schema = makeZodSchema(field);
1228
const isDecimal = field.type.type === 'Decimal';
1329

0 commit comments

Comments
 (0)