Skip to content

Commit 17fe8c3

Browse files
authored
fix(zod): zod create/update schemas should exclude discriminator fields (#1609)
1 parent 91abbb8 commit 17fe8c3

File tree

4 files changed

+148
-40
lines changed

4 files changed

+148
-40
lines changed

packages/schema/src/plugins/enhancer/enhance/index.ts

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import {
77
getDataModels,
88
getLiteral,
99
isDelegateModel,
10+
isDiscriminatorField,
1011
type PluginOptions,
1112
} from '@zenstackhq/sdk';
1213
import {
@@ -495,33 +496,34 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
495496
return source;
496497
}
497498

499+
private readonly ModelCreateUpdateInputRegex = /(\S+)(Unchecked)?(Create|Update).*Input/;
500+
498501
private removeDiscriminatorFromConcreteInput(
499502
typeAlias: TypeAliasDeclaration,
500-
delegateInfo: DelegateInfo,
503+
_delegateInfo: DelegateInfo,
501504
source: string
502505
) {
503-
// remove discriminator field from the create/update input of concrete models because
504-
// discriminator cannot be set directly
506+
// remove discriminator field from the create/update input because discriminator cannot be set directly
505507
const typeName = typeAlias.getName();
506-
const concreteModelNames = delegateInfo.map(([, concretes]) => concretes.map((c) => c.name)).flatMap((c) => c);
507-
const concreteCreateUpdateInputRegex = new RegExp(
508-
`(${concreteModelNames.join('|')})(Unchecked)?(Create|Update).*Input`
509-
);
510508

511-
const match = typeName.match(concreteCreateUpdateInputRegex);
509+
const match = typeName.match(this.ModelCreateUpdateInputRegex);
512510
if (match) {
513511
const modelName = match[1];
514-
const record = delegateInfo.find(([, concretes]) => concretes.some((c) => c.name === modelName));
515-
if (record) {
516-
// remove all discriminator fields recursively
517-
const delegateOfConcrete = record[0];
518-
const discriminators = this.getDiscriminatorFieldsRecursively(delegateOfConcrete);
519-
discriminators.forEach((discriminatorDecl) => {
520-
const discriminatorNode = this.findNamedProperty(typeAlias, discriminatorDecl.name);
521-
if (discriminatorNode) {
522-
source = source.replace(discriminatorNode.getText(), '');
512+
const dataModel = this.model.declarations.find(
513+
(d): d is DataModel => isDataModel(d) && d.name === modelName
514+
);
515+
516+
if (!dataModel) {
517+
return source;
518+
}
519+
520+
for (const field of dataModel.fields) {
521+
if (isDiscriminatorField(field)) {
522+
const fieldDef = this.findNamedProperty(typeAlias, field.name);
523+
if (fieldDef) {
524+
source = source.replace(fieldDef.getText(), '');
523525
}
524-
});
526+
}
525527
}
526528
}
527529
return source;
@@ -618,22 +620,6 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
618620
return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined;
619621
}
620622

621-
private getDiscriminatorFieldsRecursively(delegate: DataModel, result: DataModelField[] = []) {
622-
if (isDelegateModel(delegate)) {
623-
const discriminator = this.getDiscriminatorField(delegate);
624-
if (discriminator) {
625-
result.push(discriminator);
626-
}
627-
628-
for (const superType of delegate.superTypes) {
629-
if (superType.ref) {
630-
result.push(...this.getDiscriminatorFieldsRecursively(superType.ref, result));
631-
}
632-
}
633-
}
634-
return result;
635-
}
636-
637623
private async saveSourceFile(sf: SourceFile) {
638624
if (this.options.preserveTsFiles) {
639625
await sf.save();

packages/schema/src/plugins/zod/generator.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import {
55
ensureEmptyDir,
66
getDataModels,
77
hasAttribute,
8+
isDiscriminatorField,
89
isEnumFieldReference,
910
isForeignKeyField,
1011
isFromStdlib,
@@ -368,6 +369,13 @@ export function ${refineFuncName}<T, D extends z.ZodTypeDef>(schema: z.ZodType<T
368369
);
369370
}
370371

372+
// delegate discriminator fields are to be excluded from mutation schemas
373+
const delegateFields = model.fields.filter((field) => isDiscriminatorField(field));
374+
const omitDiscriminators =
375+
delegateFields.length > 0
376+
? `.omit({ ${delegateFields.map((f) => `${f.name}: true`).join(', ')} })`
377+
: '';
378+
371379
////////////////////////////////////////////////
372380
// 1. Model schema
373381
////////////////////////////////////////////////
@@ -429,7 +437,7 @@ export const ${upperCaseFirst(model.name)}Schema = ${modelSchema};
429437
////////////////////////////////////////////////
430438

431439
// schema for validating prisma create input (all fields optional)
432-
let prismaCreateSchema = this.makePassthrough(this.makePartial('baseSchema'));
440+
let prismaCreateSchema = this.makePassthrough(this.makePartial(`baseSchema${omitDiscriminators}`));
433441
if (refineFuncName) {
434442
prismaCreateSchema = `${refineFuncName}(${prismaCreateSchema})`;
435443
}
@@ -445,6 +453,7 @@ export const ${upperCaseFirst(model.name)}PrismaCreateSchema = ${prismaCreateSch
445453
// note numeric fields can be simple update or atomic operations
446454
let prismaUpdateSchema = `z.object({
447455
${scalarFields
456+
.filter((f) => !isDiscriminatorField(f))
448457
.map((field) => {
449458
let fieldSchema = makeFieldSchema(field);
450459
if (field.type.type === 'Int' || field.type.type === 'Float') {
@@ -472,7 +481,7 @@ export const ${upperCaseFirst(model.name)}PrismaUpdateSchema = ${prismaUpdateSch
472481
// 3. Create schema
473482
////////////////////////////////////////////////
474483

475-
let createSchema = 'baseSchema';
484+
let createSchema = `baseSchema${omitDiscriminators}`;
476485
const fieldsWithDefault = scalarFields.filter(
477486
(field) => hasAttribute(field, '@default') || hasAttribute(field, '@updatedAt') || field.type.array
478487
);
@@ -524,7 +533,7 @@ export const ${upperCaseFirst(model.name)}CreateSchema = ${createSchema};
524533
////////////////////////////////////////////////
525534

526535
// for update all fields are optional
527-
let updateSchema = this.makePartial('baseSchema');
536+
let updateSchema = this.makePartial(`baseSchema${omitDiscriminators}`);
528537

529538
// export schema with only scalar fields: `[Model]UpdateScalarSchema`
530539
const updateScalarSchema = `${upperCaseFirst(model.name)}UpdateScalarSchema`;

packages/schema/src/plugins/zod/transformer.ts

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/* eslint-disable @typescript-eslint/ban-ts-comment */
2-
import { indentString, type PluginOptions } from '@zenstackhq/sdk';
3-
import type { Model } from '@zenstackhq/sdk/ast';
2+
import { indentString, isDiscriminatorField, type PluginOptions } from '@zenstackhq/sdk';
3+
import { DataModel, isDataModel, type Model } from '@zenstackhq/sdk/ast';
44
import { checkModelHasModelRelation, findModelByName, isAggregateInputType } from '@zenstackhq/sdk/dmmf-helpers';
55
import { supportCreateMany, type DMMF as PrismaDMMF } from '@zenstackhq/sdk/prisma';
66
import path from 'path';
@@ -90,8 +90,31 @@ export default class Transformer {
9090
return `${this.name}.schema`;
9191
}
9292

93+
private delegateCreateUpdateInputRegex = /(\S+)(Unchecked)?(Create|Update).*Input/;
94+
9395
generateObjectSchemaFields(generateUnchecked: boolean) {
94-
const zodObjectSchemaFields = this.fields
96+
let fields = this.fields;
97+
98+
// exclude discriminator fields from create/update input schemas
99+
const createUpdateMatch = this.delegateCreateUpdateInputRegex.exec(this.name);
100+
if (createUpdateMatch) {
101+
const modelName = createUpdateMatch[1];
102+
const dataModel = this.zmodel.declarations.find(
103+
(d): d is DataModel => isDataModel(d) && d.name === modelName
104+
);
105+
if (dataModel) {
106+
const discriminatorFields = dataModel.fields.filter(isDiscriminatorField);
107+
if (discriminatorFields.length > 0) {
108+
fields = fields.filter((field) => {
109+
return !discriminatorFields.some(
110+
(discriminatorField) => discriminatorField.name === field.name
111+
);
112+
});
113+
}
114+
}
115+
}
116+
117+
const zodObjectSchemaFields = fields
95118
.map((field) => this.generateObjectSchemaField(field, generateUnchecked))
96119
.flatMap((item) => item)
97120
.map((item) => {

tests/integration/tests/enhancements/with-delegate/plugin-interaction.test.ts

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,94 @@ describe('Polymorphic Plugin Interaction Test', () => {
5555
extraDependencies: ['@trpc/client', '@trpc/server', '@trpc/react-query'],
5656
});
5757
});
58+
59+
it('zod', async () => {
60+
const { zodSchemas } = await loadSchema(POLYMORPHIC_SCHEMA, { fullZod: true });
61+
62+
// model schema
63+
expect(
64+
zodSchemas.models.AssetSchema.parse({
65+
id: 1,
66+
assetType: 'video',
67+
createdAt: new Date(),
68+
viewCount: 100,
69+
})
70+
).toBeTruthy();
71+
72+
expect(
73+
zodSchemas.models.AssetSchema.parse({
74+
id: 1,
75+
assetType: 'video',
76+
createdAt: new Date(),
77+
viewCount: 100,
78+
videoType: 'ratedVideo', // should be stripped
79+
}).videoType
80+
).toBeUndefined();
81+
82+
expect(
83+
zodSchemas.models.VideoSchema.parse({
84+
id: 1,
85+
assetType: 'video',
86+
videoType: 'ratedVideo',
87+
duration: 100,
88+
url: 'http://example.com',
89+
createdAt: new Date(),
90+
viewCount: 100,
91+
})
92+
).toBeTruthy();
93+
94+
expect(() =>
95+
zodSchemas.models.VideoSchema.parse({
96+
id: 1,
97+
assetType: 'video',
98+
videoType: 'ratedVideo',
99+
url: 'http://example.com',
100+
createdAt: new Date(),
101+
viewCount: 100,
102+
})
103+
).toThrow('duration');
104+
105+
// create schema
106+
expect(
107+
zodSchemas.models.VideoCreateSchema.parse({
108+
duration: 100,
109+
url: 'http://example.com',
110+
}).assetType // discriminator should not be set
111+
).toBeUndefined();
112+
113+
// update schema
114+
expect(
115+
zodSchemas.models.VideoUpdateSchema.parse({
116+
duration: 100,
117+
url: 'http://example.com',
118+
}).assetType // discriminator should not be set
119+
).toBeUndefined();
120+
121+
// prisma create schema
122+
expect(
123+
zodSchemas.models.VideoPrismaCreateSchema.strip().parse({
124+
assetType: 'video',
125+
}).assetType // discriminator should not be set
126+
).toBeUndefined();
127+
128+
// input object schema
129+
expect(
130+
zodSchemas.objects.RatedVideoCreateInputObjectSchema.parse({
131+
duration: 100,
132+
viewCount: 200,
133+
url: 'http://www.example.com',
134+
rating: 5,
135+
})
136+
).toBeTruthy();
137+
138+
expect(() =>
139+
zodSchemas.objects.RatedVideoCreateInputObjectSchema.parse({
140+
duration: 100,
141+
viewCount: 200,
142+
url: 'http://www.example.com',
143+
rating: 5,
144+
videoType: 'ratedVideo',
145+
})
146+
).toThrow('videoType');
147+
});
58148
});

0 commit comments

Comments
 (0)