4
4
PluginOptions ,
5
5
createProject ,
6
6
emitProject ,
7
+ getAttribute ,
8
+ getAttributeArg ,
7
9
getDataModels ,
8
10
getLiteral ,
9
11
getPrismaClientImportSpec ,
@@ -15,7 +17,16 @@ import {
15
17
resolvePath ,
16
18
saveProject ,
17
19
} from '@zenstackhq/sdk' ;
18
- import { DataModel , DataSource , EnumField , Model , isDataModel , isDataSource , isEnum } from '@zenstackhq/sdk/ast' ;
20
+ import {
21
+ DataModel ,
22
+ DataModelField ,
23
+ DataSource ,
24
+ EnumField ,
25
+ Model ,
26
+ isDataModel ,
27
+ isDataSource ,
28
+ isEnum ,
29
+ } from '@zenstackhq/sdk/ast' ;
19
30
import { addMissingInputObjectTypes , resolveAggregateOperationSupport } from '@zenstackhq/sdk/dmmf-helpers' ;
20
31
import { promises as fs } from 'fs' ;
21
32
import { streamAllContents } from 'langium' ;
@@ -262,10 +273,17 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
262
273
sf . replaceWithText ( ( writer ) => {
263
274
const fields = model . fields . filter (
264
275
( field ) =>
265
- // scalar fields only
276
+ // regular fields only
266
277
! isDataModel ( field . type . reference ?. ref ) && ! isForeignKeyField ( field )
267
278
) ;
268
279
280
+ const relations = model . fields . filter ( ( field ) => isDataModel ( field . type . reference ?. ref ) ) ;
281
+ const fkFields = model . fields . filter ( ( field ) => isForeignKeyField ( field ) ) ;
282
+ // unsafe version of relations: including foreign keys and relation fields without fk
283
+ const unsafeRelations = model . fields . filter (
284
+ ( field ) => isForeignKeyField ( field ) || ( isDataModel ( field . type . reference ?. ref ) && ! hasForeignKey ( field ) )
285
+ ) ;
286
+
269
287
writer . writeLine ( '/* eslint-disable */' ) ;
270
288
writer . writeLine ( `import { z } from 'zod';` ) ;
271
289
@@ -302,7 +320,7 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
302
320
writer . writeLine ( `import { Decimal } from 'decimal.js';` ) ;
303
321
}
304
322
305
- // create base schema
323
+ // base schema
306
324
writer . write ( `const baseSchema = z.object(` ) ;
307
325
writer . inlineBlock ( ( ) => {
308
326
fields . forEach ( ( field ) => {
@@ -311,31 +329,92 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
311
329
} ) ;
312
330
writer . writeLine ( ');' ) ;
313
331
332
+ // relation fields
333
+
334
+ let allRelationSchema : string | undefined ;
335
+ let safeRelationSchema : string | undefined ;
336
+ let unsafeRelationSchema : string | undefined ;
337
+
338
+ if ( relations . length > 0 || fkFields . length > 0 ) {
339
+ allRelationSchema = 'allRelationSchema' ;
340
+ writer . write ( `const ${ allRelationSchema } = z.object(` ) ;
341
+ writer . inlineBlock ( ( ) => {
342
+ [ ...relations , ...fkFields ] . forEach ( ( field ) => {
343
+ writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field ) } ,` ) ;
344
+ } ) ;
345
+ } ) ;
346
+ writer . writeLine ( ');' ) ;
347
+ }
348
+
349
+ if ( relations . length > 0 ) {
350
+ safeRelationSchema = 'safeRelationSchema' ;
351
+ writer . write ( `const ${ safeRelationSchema } = z.object(` ) ;
352
+ writer . inlineBlock ( ( ) => {
353
+ relations . forEach ( ( field ) => {
354
+ writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field , true ) } ,` ) ;
355
+ } ) ;
356
+ } ) ;
357
+ writer . writeLine ( ');' ) ;
358
+ }
359
+
360
+ if ( unsafeRelations . length > 0 ) {
361
+ unsafeRelationSchema = 'unsafeRelationSchema' ;
362
+ writer . write ( `const ${ unsafeRelationSchema } = z.object(` ) ;
363
+ writer . inlineBlock ( ( ) => {
364
+ unsafeRelations . forEach ( ( field ) => {
365
+ writer . writeLine ( `${ field . name } : ${ makeFieldSchema ( field , true ) } ,` ) ;
366
+ } ) ;
367
+ } ) ;
368
+ writer . writeLine ( ');' ) ;
369
+ }
370
+
314
371
// compile "@@validate" to ".refine"
315
372
const refinements = makeValidationRefinements ( model ) ;
373
+ let refineFuncName : string | undefined ;
316
374
if ( refinements . length > 0 ) {
375
+ refineFuncName = `refine${ upperCaseFirst ( model . name ) } ` ;
317
376
writer . writeLine (
318
- `function refine <T, D extends z.ZodTypeDef>(schema: z.ZodType<T, D, T>) { return schema${ refinements . join (
377
+ `export function ${ refineFuncName } <T, D extends z.ZodTypeDef>(schema: z.ZodType<T, D, T>) { return schema${ refinements . join (
319
378
'\n'
320
379
) } ; }`
321
380
) ;
322
381
}
323
382
324
- // model schema
383
+ ////////////////////////////////////////////////
384
+ // 1. Model schema
385
+ ////////////////////////////////////////////////
325
386
let modelSchema = 'baseSchema' ;
387
+
388
+ // omit fields
326
389
const fieldsToOmit = fields . filter ( ( field ) => hasAttribute ( field , '@omit' ) ) ;
327
390
if ( fieldsToOmit . length > 0 ) {
328
391
modelSchema = makeOmit (
329
392
modelSchema ,
330
393
fieldsToOmit . map ( ( f ) => f . name )
331
394
) ;
332
395
}
333
- if ( refinements . length > 0 ) {
334
- modelSchema = `refine(${ modelSchema } )` ;
396
+
397
+ if ( allRelationSchema ) {
398
+ // export schema with only scalar fields
399
+ const modelScalarSchema = `${ upperCaseFirst ( model . name ) } ScalarSchema` ;
400
+ writer . writeLine ( `export const ${ modelScalarSchema } = ${ modelSchema } ;` ) ;
401
+ modelSchema = modelScalarSchema ;
402
+
403
+ // merge relations
404
+ modelSchema = makeMerge ( modelSchema , allRelationSchema ) ;
405
+ }
406
+
407
+ // refine
408
+ if ( refineFuncName ) {
409
+ const noRefineSchema = `${ upperCaseFirst ( model . name ) } WithoutRefineSchema` ;
410
+ writer . writeLine ( `export const ${ noRefineSchema } = ${ modelSchema } ;` ) ;
411
+ modelSchema = `${ refineFuncName } (${ noRefineSchema } )` ;
335
412
}
336
413
writer . writeLine ( `export const ${ upperCaseFirst ( model . name ) } Schema = ${ modelSchema } ;` ) ;
337
414
338
- // create schema
415
+ ////////////////////////////////////////////////
416
+ // 2. Create schema
417
+ ////////////////////////////////////////////////
339
418
let createSchema = 'baseSchema' ;
340
419
const fieldsWithDefault = fields . filter (
341
420
( field ) => hasAttribute ( field , '@default' ) || hasAttribute ( field , '@updatedAt' ) || field . type . array
@@ -346,29 +425,104 @@ async function generateModelSchema(model: DataModel, project: Project, output: s
346
425
fieldsWithDefault . map ( ( f ) => f . name )
347
426
) ;
348
427
}
349
- if ( refinements . length > 0 ) {
350
- createSchema = `refine(${ createSchema } )` ;
428
+
429
+ if ( safeRelationSchema || unsafeRelationSchema ) {
430
+ // export schema with only scalar fields
431
+ const createScalarSchema = `${ upperCaseFirst ( model . name ) } CreateScalarSchema` ;
432
+ writer . writeLine ( `export const ${ createScalarSchema } = ${ createSchema } ;` ) ;
433
+ createSchema = createScalarSchema ;
434
+
435
+ if ( safeRelationSchema && unsafeRelationSchema ) {
436
+ // build a union of with relation object fields and with fk fields (mutually exclusive)
437
+
438
+ // TODO: we make all relation fields partial for now because in case of
439
+ // nested create, not all relation/fk fields are inside payload, need a
440
+ // better solution
441
+ createSchema = makeUnion (
442
+ makeMerge ( createSchema , makePartial ( safeRelationSchema ) ) ,
443
+ makeMerge ( createSchema , makePartial ( unsafeRelationSchema ) )
444
+ ) ;
445
+ } else if ( safeRelationSchema ) {
446
+ // just relation
447
+
448
+ // TODO: we make all relation fields partial for now because in case of
449
+ // nested create, not all relation/fk fields are inside payload, need a
450
+ // better solution
451
+ createSchema = makeMerge ( createSchema , makePartial ( safeRelationSchema ) ) ;
452
+ }
453
+ }
454
+
455
+ if ( refineFuncName ) {
456
+ // export a schema without refinement for extensibility
457
+ const noRefineSchema = `${ upperCaseFirst ( model . name ) } CreateWithoutRefineSchema` ;
458
+ writer . writeLine ( `export const ${ noRefineSchema } = ${ createSchema } ;` ) ;
459
+ createSchema = `${ refineFuncName } (${ noRefineSchema } )` ;
351
460
}
352
461
writer . writeLine ( `export const ${ upperCaseFirst ( model . name ) } CreateSchema = ${ createSchema } ;` ) ;
353
462
354
- // update schema
355
- let updateSchema = 'baseSchema.partial()' ;
356
- if ( refinements . length > 0 ) {
357
- updateSchema = `refine(${ updateSchema } )` ;
463
+ ////////////////////////////////////////////////
464
+ // 3. Update schema
465
+ ////////////////////////////////////////////////
466
+ let updateSchema = makePartial ( 'baseSchema' ) ;
467
+
468
+ if ( safeRelationSchema || unsafeRelationSchema ) {
469
+ // export schema with only scalar fields
470
+ const updateScalarSchema = `${ upperCaseFirst ( model . name ) } UpdateScalarSchema` ;
471
+ writer . writeLine ( `export const ${ updateScalarSchema } = ${ updateSchema } ;` ) ;
472
+ updateSchema = updateScalarSchema ;
473
+
474
+ if ( safeRelationSchema && unsafeRelationSchema ) {
475
+ // build a union of with relation object fields and with fk fields (mutually exclusive)
476
+ updateSchema = makeUnion (
477
+ makeMerge ( updateSchema , makePartial ( safeRelationSchema ) ) ,
478
+ makeMerge ( updateSchema , makePartial ( unsafeRelationSchema ) )
479
+ ) ;
480
+ } else if ( safeRelationSchema ) {
481
+ // just relation
482
+ updateSchema = makeMerge ( updateSchema , makePartial ( safeRelationSchema ) ) ;
483
+ }
484
+ }
485
+
486
+ if ( refineFuncName ) {
487
+ // export a schema without refinement for extensibility
488
+ const noRefineSchema = `${ upperCaseFirst ( model . name ) } UpdateWithoutRefineSchema` ;
489
+ writer . writeLine ( `export const ${ noRefineSchema } = ${ updateSchema } ;` ) ;
490
+ updateSchema = `${ refineFuncName } (${ noRefineSchema } )` ;
358
491
}
359
492
writer . writeLine ( `export const ${ upperCaseFirst ( model . name ) } UpdateSchema = ${ updateSchema } ;` ) ;
360
493
} ) ;
494
+
361
495
return schemaName ;
362
496
}
363
497
364
- function makePartial ( schema : string , fields : string [ ] ) {
365
- return `${ schema } .partial({
498
+ function makePartial ( schema : string , fields ?: string [ ] ) {
499
+ if ( fields ) {
500
+ return `${ schema } .partial({
366
501
${ fields . map ( ( f ) => `${ f } : true` ) . join ( ', ' ) } ,
367
502
})` ;
503
+ } else {
504
+ return `${ schema } .partial()` ;
505
+ }
368
506
}
369
507
370
508
function makeOmit ( schema : string , fields : string [ ] ) {
371
509
return `${ schema } .omit({
372
510
${ fields . map ( ( f ) => `${ f } : true` ) . join ( ', ' ) } ,
373
511
})` ;
374
512
}
513
+
514
+ function makeMerge ( schema1 : string , schema2 : string ) : string {
515
+ return `${ schema1 } .merge(${ schema2 } )` ;
516
+ }
517
+
518
+ function makeUnion ( ...schemas : string [ ] ) : string {
519
+ return `z.union([${ schemas . join ( ', ' ) } ])` ;
520
+ }
521
+
522
+ function hasForeignKey ( field : DataModelField ) {
523
+ const relAttr = getAttribute ( field , '@relation' ) ;
524
+ if ( ! relAttr ) {
525
+ return false ;
526
+ }
527
+ return ! ! getAttributeArg ( relAttr , 'fields' ) ;
528
+ }
0 commit comments