Skip to content

Commit a5d15a3

Browse files
authored
feat: support configuring what models to include for zod and trpc plugins (#747)
1 parent 30b95eb commit a5d15a3

File tree

6 files changed

+389
-41
lines changed

6 files changed

+389
-41
lines changed

packages/plugins/trpc/src/generator.ts

Lines changed: 14 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import {
55
PluginOptions,
66
RUNTIME_PACKAGE,
77
getPrismaClientImportSpec,
8+
parseOptionAsStrings,
89
requireOption,
910
resolvePath,
1011
saveProject,
@@ -32,11 +33,14 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.
3233
let outDir = requireOption<string>(options, 'output');
3334
outDir = resolvePath(outDir, options);
3435

36+
// resolve "generateModels" option
37+
const generateModels = parseOptionAsStrings(options, 'generateModels', name);
38+
3539
// resolve "generateModelActions" option
36-
const generateModelActions = parseOptionAsStrings(options, 'generateModelActions');
40+
const generateModelActions = parseOptionAsStrings(options, 'generateModelActions', name);
3741

3842
// resolve "generateClientHelpers" option
39-
const generateClientHelpers = parseOptionAsStrings(options, 'generateClientHelpers');
43+
const generateClientHelpers = parseOptionAsStrings(options, 'generateClientHelpers', name);
4044
if (generateClientHelpers && !generateClientHelpers.every((v) => ['react', 'next'].includes(v))) {
4145
throw new PluginError(name, `Option "generateClientHelpers" only support values "react" and "next"`);
4246
}
@@ -50,10 +54,15 @@ export async function generate(model: Model, options: PluginOptions, dmmf: DMMF.
5054

5155
const prismaClientDmmf = dmmf;
5256

53-
const modelOperations = prismaClientDmmf.mappings.modelOperations;
54-
const models = prismaClientDmmf.datamodel.models;
57+
let modelOperations = prismaClientDmmf.mappings.modelOperations;
58+
if (generateModels) {
59+
modelOperations = modelOperations.filter((mo) => generateModels.includes(mo.model));
60+
}
61+
62+
// TODO: remove this legacy code that deals with "@Gen.hide" comment syntax inherited
63+
// from original code
5564
const hiddenModels: string[] = [];
56-
resolveModelsComments(models, hiddenModels);
65+
resolveModelsComments(prismaClientDmmf.datamodel.models, hiddenModels);
5766

5867
const zodSchemasImport = (options.zodSchemasImport as string) ?? '@zenstackhq/runtime/zod';
5968
createAppRouter(
@@ -472,24 +481,3 @@ function createHelper(outDir: string) {
472481
);
473482
checkRead.formatText();
474483
}
475-
476-
function parseOptionAsStrings(options: PluginOptions, optionaName: string) {
477-
const value = options[optionaName];
478-
if (value === undefined) {
479-
return undefined;
480-
} else if (typeof value === 'string') {
481-
// comma separated string
482-
return value
483-
.split(',')
484-
.filter((i) => !!i)
485-
.map((i) => i.trim());
486-
} else if (Array.isArray(value) && value.every((i) => typeof i === 'string')) {
487-
// string array
488-
return value as string[];
489-
} else {
490-
throw new PluginError(
491-
name,
492-
`Invalid "${optionaName}" option: must be a comma-separated string or an array of strings`
493-
);
494-
}
495-
}

packages/plugins/trpc/tests/trpc.test.ts

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,4 +285,133 @@ model post_item {
285285
}
286286
);
287287
});
288+
289+
it('generate for selected models and actions', async () => {
290+
const { projectDir } = await loadSchema(
291+
`
292+
datasource db {
293+
provider = 'postgresql'
294+
url = env('DATABASE_URL')
295+
}
296+
297+
generator js {
298+
provider = 'prisma-client-js'
299+
}
300+
301+
plugin trpc {
302+
provider = '${process.cwd()}/dist'
303+
output = '$projectRoot/trpc'
304+
generateModels = ['Post']
305+
generateModelActions = ['findMany', 'update']
306+
}
307+
308+
model User {
309+
id String @id
310+
email String @unique
311+
posts Post[]
312+
}
313+
314+
model Post {
315+
id String @id
316+
title String
317+
author User? @relation(fields: [authorId], references: [id])
318+
authorId String?
319+
}
320+
321+
model Foo {
322+
id String @id
323+
value Int
324+
}
325+
`,
326+
{
327+
addPrelude: false,
328+
pushDb: false,
329+
extraDependencies: [`${origDir}/dist`, '@trpc/client', '@trpc/server'],
330+
compile: true,
331+
}
332+
);
333+
334+
expect(fs.existsSync(path.join(projectDir, 'trpc/routers/User.router.ts'))).toBeFalsy();
335+
expect(fs.existsSync(path.join(projectDir, 'trpc/routers/Foo.router.ts'))).toBeFalsy();
336+
expect(fs.existsSync(path.join(projectDir, 'trpc/routers/Post.router.ts'))).toBeTruthy();
337+
338+
const postRouterContent = fs.readFileSync(path.join(projectDir, 'trpc/routers/Post.router.ts'), 'utf8');
339+
expect(postRouterContent).toContain('findMany:');
340+
expect(postRouterContent).toContain('update:');
341+
expect(postRouterContent).not.toContain('findUnique:');
342+
expect(postRouterContent).not.toContain('create:');
343+
344+
// trpc plugin passes "generateModels" option down to implicitly enabled zod plugin
345+
346+
expect(
347+
fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/PostInput.schema.js'))
348+
).toBeTruthy();
349+
// zod for User is generated due to transitive dependency
350+
expect(
351+
fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/UserInput.schema.js'))
352+
).toBeTruthy();
353+
expect(fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/FooInput.schema.js'))).toBeFalsy();
354+
});
355+
356+
it('generate for selected models with zod plugin declared', async () => {
357+
const { projectDir } = await loadSchema(
358+
`
359+
datasource db {
360+
provider = 'postgresql'
361+
url = env('DATABASE_URL')
362+
}
363+
364+
generator js {
365+
provider = 'prisma-client-js'
366+
}
367+
368+
plugin zod {
369+
provider = '@core/zod'
370+
}
371+
372+
plugin trpc {
373+
provider = '${process.cwd()}/dist'
374+
output = '$projectRoot/trpc'
375+
generateModels = ['Post']
376+
generateModelActions = ['findMany', 'update']
377+
}
378+
379+
model User {
380+
id String @id
381+
email String @unique
382+
posts Post[]
383+
}
384+
385+
model Post {
386+
id String @id
387+
title String
388+
author User? @relation(fields: [authorId], references: [id])
389+
authorId String?
390+
}
391+
392+
model Foo {
393+
id String @id
394+
value Int
395+
}
396+
`,
397+
{
398+
addPrelude: false,
399+
pushDb: false,
400+
extraDependencies: [`${origDir}/dist`, '@trpc/client', '@trpc/server'],
401+
compile: true,
402+
}
403+
);
404+
405+
// trpc plugin's "generateModels" shouldn't interfere in this case
406+
407+
expect(
408+
fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/PostInput.schema.js'))
409+
).toBeTruthy();
410+
expect(
411+
fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/UserInput.schema.js'))
412+
).toBeTruthy();
413+
expect(
414+
fs.existsSync(path.join(projectDir, 'node_modules/.zenstack/zod/input/FooInput.schema.js'))
415+
).toBeTruthy();
416+
});
288417
});

packages/schema/src/cli/plugin-runner.ts

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ export class PluginRunner {
184184
}
185185

186186
// "@core/access-policy" has implicit requirements
187+
let zodImplicitlyAdded = false;
187188
if ([...plugins, ...corePlugins].find((p) => p.provider === '@core/access-policy')) {
188189
// make sure "@core/model-meta" is enabled
189190
if (!corePlugins.find((p) => p.provider === '@core/model-meta')) {
@@ -193,25 +194,52 @@ export class PluginRunner {
193194
// '@core/zod' plugin is auto-enabled by "@core/access-policy"
194195
// if there're validation rules
195196
if (!corePlugins.find((p) => p.provider === '@core/zod') && this.hasValidation(options.schema)) {
197+
zodImplicitlyAdded = true;
196198
corePlugins.push({ provider: '@core/zod', options: { modelOnly: true } });
197199
}
198200
}
199201

200202
// core plugins introduced by dependencies
201-
plugins
202-
.flatMap((p) => p.dependencies)
203-
.forEach((dep) => {
203+
plugins.forEach((plugin) => {
204+
// TODO: generalize this
205+
const isTrpcPlugin =
206+
plugin.provider === '@zenstackhq/trpc' ||
207+
// for testing
208+
(process.env.ZENSTACK_TEST && plugin.provider.includes('trpc'));
209+
210+
for (const dep of plugin.dependencies) {
204211
if (dep.startsWith('@core/')) {
205212
const existing = corePlugins.find((p) => p.provider === dep);
206213
if (existing) {
207-
// reset options to default
208-
existing.options = undefined;
214+
// TODO: generalize this
215+
if (existing.provider === '@core/zod') {
216+
// Zod plugin can be automatically enabled in `modelOnly` mode, however
217+
// other plugin (tRPC) for now requires it to run in full mode
218+
existing.options = {};
219+
220+
if (
221+
isTrpcPlugin &&
222+
zodImplicitlyAdded // don't do it for user defined zod plugin
223+
) {
224+
// pass trpc plugin's `generateModels` option down to zod plugin
225+
existing.options.generateModels = plugin.options.generateModels;
226+
}
227+
}
209228
} else {
210229
// add core dependency
211-
corePlugins.push({ provider: dep });
230+
const toAdd = { provider: dep, options: {} as Record<string, unknown> };
231+
232+
// TODO: generalize this
233+
if (dep === '@core/zod' && isTrpcPlugin) {
234+
// pass trpc plugin's `generateModels` option down to zod plugin
235+
toAdd.options.generateModels = plugin.options.generateModels;
236+
}
237+
238+
corePlugins.push(toAdd);
212239
}
213240
}
214-
});
241+
}
242+
});
215243

216244
return corePlugins;
217245
}

packages/schema/src/plugins/zod/generator.ts

Lines changed: 64 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import {
1111
isEnumFieldReference,
1212
isForeignKeyField,
1313
isFromStdlib,
14+
parseOptionAsStrings,
1415
resolvePath,
1516
saveProject,
1617
} from '@zenstackhq/sdk';
@@ -21,6 +22,7 @@ import { streamAllContents } from 'langium';
2122
import path from 'path';
2223
import { Project } from 'ts-morph';
2324
import { upperCaseFirst } from 'upper-case-first';
25+
import { name } from '.';
2426
import { getDefaultOutputFolder } from '../plugin-utils';
2527
import Transformer from './transformer';
2628
import removeDir from './utils/removeDir';
@@ -44,12 +46,26 @@ export async function generate(
4446
output = resolvePath(output, options);
4547
await handleGeneratorOutputValue(output);
4648

49+
// calculate the models to be excluded
50+
const excludeModels = getExcludedModels(model, options);
51+
4752
const prismaClientDmmf = dmmf;
4853

49-
const modelOperations = prismaClientDmmf.mappings.modelOperations;
50-
const inputObjectTypes = prismaClientDmmf.schema.inputObjectTypes.prisma;
51-
const outputObjectTypes = prismaClientDmmf.schema.outputObjectTypes.prisma;
52-
const models: DMMF.Model[] = prismaClientDmmf.datamodel.models;
54+
const modelOperations = prismaClientDmmf.mappings.modelOperations.filter(
55+
(o) => !excludeModels.find((e) => e === o.model)
56+
);
57+
58+
// TODO: better way of filtering than string startsWith?
59+
const inputObjectTypes = prismaClientDmmf.schema.inputObjectTypes.prisma.filter(
60+
(type) => !excludeModels.find((e) => type.name.toLowerCase().startsWith(e.toLocaleLowerCase()))
61+
);
62+
const outputObjectTypes = prismaClientDmmf.schema.outputObjectTypes.prisma.filter(
63+
(type) => !excludeModels.find((e) => type.name.toLowerCase().startsWith(e.toLowerCase()))
64+
);
65+
66+
const models: DMMF.Model[] = prismaClientDmmf.datamodel.models.filter(
67+
(m) => !excludeModels.find((e) => e === m.name)
68+
);
5369

5470
// whether Prisma's Unchecked* series of input types should be generated
5571
const generateUnchecked = options.noUncheckedInput !== true;
@@ -73,7 +89,7 @@ export async function generate(
7389
dataSource?.fields.find((f) => f.name === 'provider')?.value
7490
) as ConnectorType;
7591

76-
await generateModelSchemas(project, model, output);
92+
await generateModelSchemas(project, model, output, excludeModels);
7793

7894
if (options.modelOnly !== true) {
7995
// detailed object schemas referenced from input schemas
@@ -120,6 +136,45 @@ export async function generate(
120136
}
121137
}
122138

139+
function getExcludedModels(model: Model, options: PluginOptions) {
140+
// resolve "generateModels" option
141+
const generateModels = parseOptionAsStrings(options, 'generateModels', name);
142+
if (generateModels) {
143+
if (options.modelOnly === true) {
144+
// no model reference needs to be considered, directly exclude any model not included
145+
return model.declarations
146+
.filter((d) => isDataModel(d) && !generateModels.includes(d.name))
147+
.map((m) => m.name);
148+
} else {
149+
// calculate a transitive closure of models to be included
150+
const todo = getDataModels(model).filter((dm) => generateModels.includes(dm.name));
151+
const included = new Set<DataModel>();
152+
while (todo.length > 0) {
153+
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
154+
const dm = todo.pop()!;
155+
included.add(dm);
156+
157+
// add referenced models to the todo list
158+
dm.fields
159+
.map((f) => f.type.reference?.ref)
160+
.filter((type): type is DataModel => isDataModel(type))
161+
.forEach((type) => {
162+
if (!included.has(type)) {
163+
todo.push(type);
164+
}
165+
});
166+
}
167+
168+
// finally find the models to be excluded
169+
return getDataModels(model)
170+
.filter((dm) => !included.has(dm))
171+
.map((m) => m.name);
172+
}
173+
} else {
174+
return [];
175+
}
176+
}
177+
123178
async function handleGeneratorOutputValue(output: string) {
124179
// create the output directory and delete contents that might exist from a previous run
125180
await fs.mkdir(output, { recursive: true });
@@ -184,10 +239,12 @@ async function generateObjectSchemas(
184239
);
185240
}
186241

187-
async function generateModelSchemas(project: Project, zmodel: Model, output: string) {
242+
async function generateModelSchemas(project: Project, zmodel: Model, output: string, excludedModels: string[]) {
188243
const schemaNames: string[] = [];
189244
for (const dm of getDataModels(zmodel)) {
190-
schemaNames.push(await generateModelSchema(dm, project, output));
245+
if (!excludedModels.includes(dm.name)) {
246+
schemaNames.push(await generateModelSchema(dm, project, output));
247+
}
191248
}
192249

193250
project.createSourceFile(

0 commit comments

Comments
 (0)