Skip to content

Commit 0519421

Browse files
authored
fix: automatically enable "@core/zod" plugin when there're validation rules (#535)
1 parent bc8e0c0 commit 0519421

File tree

11 files changed

+71
-156
lines changed

11 files changed

+71
-156
lines changed

packages/language/syntaxes/zmodel.tmLanguage.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
},
1111
{
1212
"name": "keyword.control.zmodel",
13-
"match": "\\b(Any|Asc|BigInt|Boolean|Bytes|ContextType|DateTime|Decimal|Desc|FieldReference|Float|Int|Json|Null|Object|String|TransitiveFieldReference|Unsupported|abstract|attribute|datasource|enum|extends|function|generator|import|in|model|plugin|sort)\\b"
13+
"match": "\\b(Any|Asc|BigInt|Boolean|Bytes|ContextType|DateTime|Decimal|Desc|FieldReference|Float|Int|Json|Null|Object|String|TransitiveFieldReference|Unsupported|abstract|attribute|datasource|enum|extends|false|function|generator|import|in|model|plugin|sort|true)\\b"
1414
},
1515
{
1616
"name": "string.quoted.double.zmodel",

packages/schema/src/cli/plugin-runner.ts

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,16 @@
33
import type { DMMF } from '@prisma/generator-helper';
44
import { getDMMF } from '@prisma/internals';
55
import { isPlugin, Plugin } from '@zenstackhq/language/ast';
6-
import { getLiteral, getLiteralArray, PluginError, PluginFunction, PluginOptions, resolvePath } from '@zenstackhq/sdk';
6+
import {
7+
getDataModels,
8+
getLiteral,
9+
getLiteralArray,
10+
hasValidationAttributes,
11+
PluginError,
12+
PluginFunction,
13+
PluginOptions,
14+
resolvePath,
15+
} from '@zenstackhq/sdk';
716
import colors from 'colors';
817
import fs from 'fs';
918
import ora from 'ora';
@@ -90,13 +99,21 @@ export class PluginRunner {
9099
}
91100

92101
// make sure prerequisites are included
93-
const corePlugins = [
94-
'@core/prisma',
95-
'@core/model-meta',
96-
'@core/access-policy',
97-
// core dependencies introduced by dependencies
98-
...plugins.flatMap((p) => p.dependencies).filter((dep) => dep.startsWith('@core/')),
99-
];
102+
const corePlugins = ['@core/prisma', '@core/model-meta', '@core/access-policy'];
103+
104+
if (getDataModels(context.schema).some((model) => hasValidationAttributes(model))) {
105+
// '@core/zod' plugin is auto-enabled if there're validation rules
106+
corePlugins.push('@core/zod');
107+
}
108+
109+
// core dependencies introduced by dependencies
110+
plugins
111+
.flatMap((p) => p.dependencies)
112+
.forEach((dep) => {
113+
if (dep.startsWith('@core/') && !corePlugins.includes(dep)) {
114+
corePlugins.push(dep);
115+
}
116+
});
100117

101118
for (const corePlugin of corePlugins.reverse()) {
102119
const existingIdx = plugins.findIndex((p) => p.provider === corePlugin);

packages/schema/src/plugins/access-policy/policy-guard-generator.ts

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import {
2424
getLiteral,
2525
getPrismaClientImportSpec,
2626
GUARD_FIELD_NAME,
27-
hasAttribute,
27+
hasValidationAttributes,
2828
PluginError,
2929
PluginOptions,
3030
resolved,
@@ -38,7 +38,7 @@ import path from 'path';
3838
import { FunctionDeclaration, SourceFile, VariableDeclarationKind } from 'ts-morph';
3939
import { name } from '.';
4040
import { isFromStdlib } from '../../language-server/utils';
41-
import { getIdFields, isAuthInvocation, VALIDATION_ATTRIBUTES } from '../../utils/ast-utils';
41+
import { getIdFields, isAuthInvocation } from '../../utils/ast-utils';
4242
import {
4343
TypeScriptExpressionTransformer,
4444
TypeScriptExpressionTransformerError,
@@ -113,7 +113,7 @@ export default class PolicyGenerator {
113113
for (const model of models) {
114114
writer.write(`${lowerCaseFirst(model.name)}:`);
115115
writer.inlineBlock(() => {
116-
writer.write(`hasValidation: ${this.hasValidationAttributes(model)}`);
116+
writer.write(`hasValidation: ${hasValidationAttributes(model)}`);
117117
});
118118
writer.writeLine(',');
119119
}
@@ -136,13 +136,6 @@ export default class PolicyGenerator {
136136
}
137137
}
138138

139-
private hasValidationAttributes(model: DataModel) {
140-
return (
141-
hasAttribute(model, '@@validate') ||
142-
model.fields.some((field) => VALIDATION_ATTRIBUTES.some((attr) => hasAttribute(field, attr)))
143-
);
144-
}
145-
146139
private getPolicyExpressions(model: DataModel, kind: PolicyKind, operation: PolicyOperationKind) {
147140
const attrs = model.attributes.filter((attr) => attr.decl.ref?.name === `@@${kind}`);
148141

packages/schema/src/res/stdlib.zmodel

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,11 @@ function isEmpty(field: Any[]): Boolean {
166166
*/
167167
attribute @@@targetField(targetField: AttributeTargetField[])
168168

169+
/**
170+
* Marks an attribute to be used for data validation.
171+
*/
172+
attribute @@@validation()
173+
169174
/**
170175
* Indicates the expression context a function can be used.
171176
*/
@@ -377,67 +382,67 @@ attribute @omit()
377382
/**
378383
* Validates length of a string field.
379384
*/
380-
attribute @length(_ min: Int?, _ max: Int?, _ message: String?) @@@targetField([StringField])
385+
attribute @length(_ min: Int?, _ max: Int?, _ message: String?) @@@targetField([StringField]) @@@validation
381386

382387
/**
383388
* Validates a string field value starts with the given text.
384389
*/
385-
attribute @startsWith(_ text: String, _ message: String?) @@@targetField([StringField])
390+
attribute @startsWith(_ text: String, _ message: String?) @@@targetField([StringField]) @@@validation
386391

387392
/**
388393
* Validates a string field value ends with the given text.
389394
*/
390-
attribute @endsWith(_ text: String, _ message: String?) @@@targetField([StringField])
395+
attribute @endsWith(_ text: String, _ message: String?) @@@targetField([StringField]) @@@validation
391396

392397
/**
393398
* Validates a string field value contains the given text.
394399
*/
395-
attribute @contains(_ text: String, _ message: String?) @@@targetField([StringField])
400+
attribute @contains(_ text: String, _ message: String?) @@@targetField([StringField]) @@@validation
396401

397402
/**
398403
* Validates a string field value matches a regex.
399404
*/
400-
attribute @regex(_ regex: String, _ message: String?) @@@targetField([StringField])
405+
attribute @regex(_ regex: String, _ message: String?) @@@targetField([StringField]) @@@validation
401406

402407
/**
403408
* Validates a string field value is a valid email address.
404409
*/
405-
attribute @email(_ message: String?) @@@targetField([StringField])
410+
attribute @email(_ message: String?) @@@targetField([StringField]) @@@validation
406411

407412
/**
408413
* Validates a string field value is a valid ISO datetime.
409414
*/
410-
attribute @datetime(_ message: String?) @@@targetField([StringField])
415+
attribute @datetime(_ message: String?) @@@targetField([StringField]) @@@validation
411416

412417
/**
413418
* Validates a string field value is a valid url.
414419
*/
415-
attribute @url(_ message: String?) @@@targetField([StringField])
420+
attribute @url(_ message: String?) @@@targetField([StringField]) @@@validation
416421

417422
/**
418423
* Validates a number field is greater than the given value.
419424
*/
420-
attribute @gt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField])
425+
attribute @gt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation
421426

422427
/**
423428
* Validates a number field is greater than or equal to the given value.
424429
*/
425-
attribute @gte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField])
430+
attribute @gte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation
426431

427432
/**
428433
* Validates a number field is less than the given value.
429434
*/
430-
attribute @lt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField])
435+
attribute @lt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation
431436

432437
/**
433438
* Validates a number field is less than or equal to the given value.
434439
*/
435-
attribute @lte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField])
440+
attribute @lte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation
436441

437442
/**
438443
* Validates the entity with a complex condition.
439444
*/
440-
attribute @@validate(_ value: Boolean, _ message: String?)
445+
attribute @@validate(_ value: Boolean, _ message: String?) @@@validation
441446

442447
/**
443448
* Validates length of a string field.

packages/schema/src/utils/ast-utils.ts

Lines changed: 0 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import {
22
DataModel,
3-
DataModelAttribute,
43
DataModelField,
54
Expression,
65
isArrayExpr,
@@ -14,8 +13,6 @@ import {
1413
ModelImport,
1514
ReferenceExpr,
1615
} from '@zenstackhq/language/ast';
17-
import { PolicyOperationKind } from '@zenstackhq/runtime';
18-
import { getLiteral } from '@zenstackhq/sdk';
1916
import { AstNode, getDocument, LangiumDocuments, Mutable } from 'langium';
2017
import { URI, Utils } from 'vscode-uri';
2118
import { isFromStdlib } from '../language-server/utils';
@@ -26,31 +23,6 @@ export function extractDataModelsWithAllowRules(model: Model): DataModel[] {
2623
) as DataModel[];
2724
}
2825

29-
export function analyzePolicies(dataModel: DataModel) {
30-
const allows = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@allow');
31-
const denies = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@deny');
32-
33-
const create = toStaticPolicy('create', allows, denies);
34-
const read = toStaticPolicy('read', allows, denies);
35-
const update = toStaticPolicy('update', allows, denies);
36-
const del = toStaticPolicy('delete', allows, denies);
37-
const hasFieldValidation = dataModel.$resolvedFields.some((field) =>
38-
field.attributes.some((attr) => VALIDATION_ATTRIBUTES.includes(attr.decl.$refText))
39-
);
40-
41-
return {
42-
allows,
43-
denies,
44-
create,
45-
read,
46-
update,
47-
delete: del,
48-
allowAll: create === true && read === true && update === true && del === true,
49-
denyAll: create === false && read === false && update === false && del === false,
50-
hasFieldValidation,
51-
};
52-
}
53-
5426
export function mergeBaseModel(model: Model) {
5527
model.declarations
5628
.filter((x) => x.$type === 'DataModel')
@@ -82,61 +54,6 @@ function updateContainer<T extends AstNode>(nodes: T[], container: AstNode): Mut
8254
});
8355
}
8456

85-
function toStaticPolicy(
86-
operation: PolicyOperationKind,
87-
allows: DataModelAttribute[],
88-
denies: DataModelAttribute[]
89-
): boolean | undefined {
90-
const filteredDenies = forOperation(operation, denies);
91-
if (filteredDenies.some((rule) => getLiteral<boolean>(rule.args[1].value) === true)) {
92-
// any constant true deny rule
93-
return false;
94-
}
95-
96-
const filteredAllows = forOperation(operation, allows);
97-
if (filteredAllows.length === 0) {
98-
// no allow rule
99-
return false;
100-
}
101-
102-
if (
103-
filteredDenies.length === 0 &&
104-
filteredAllows.some((rule) => getLiteral<boolean>(rule.args[1].value) === true)
105-
) {
106-
// any constant true allow rule
107-
return true;
108-
}
109-
return undefined;
110-
}
111-
112-
function forOperation(operation: PolicyOperationKind, rules: DataModelAttribute[]) {
113-
return rules.filter((rule) => {
114-
const ops = getLiteral<string>(rule.args[0].value);
115-
if (!ops) {
116-
return false;
117-
}
118-
if (ops === 'all') {
119-
return true;
120-
}
121-
const splitOps = ops.split(',').map((p) => p.trim());
122-
return splitOps.includes(operation);
123-
});
124-
}
125-
126-
export const VALIDATION_ATTRIBUTES = [
127-
'@length',
128-
'@regex',
129-
'@startsWith',
130-
'@endsWith',
131-
'@email',
132-
'@url',
133-
'@datetime',
134-
'@gt',
135-
'@gte',
136-
'@lt',
137-
'@lte',
138-
];
139-
14057
export function getIdFields(dataModel: DataModel) {
14158
const fieldLevelId = dataModel.$resolvedFields.find((f) =>
14259
f.attributes.some((attr) => attr.decl.$refText === '@id')

packages/schema/tests/plugins/zod.test.ts

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,6 @@ describe('Zod plugin tests', () => {
2525
provider = 'prisma-client-js'
2626
}
2727
28-
plugin zod {
29-
provider = '@core/zod'
30-
}
31-
3228
enum Role {
3329
USER
3430
ADMIN
@@ -123,10 +119,6 @@ describe('Zod plugin tests', () => {
123119
provider = 'prisma-client-js'
124120
}
125121
126-
plugin zod {
127-
provider = '@core/zod'
128-
}
129-
130122
model M {
131123
id Int @id @default(autoincrement())
132124
a String? @length(5, 10, 'must be between 5 and 10')
@@ -219,10 +211,6 @@ describe('Zod plugin tests', () => {
219211
provider = 'prisma-client-js'
220212
}
221213
222-
plugin zod {
223-
provider = '@core/zod'
224-
}
225-
226214
model M {
227215
id Int @id @default(autoincrement())
228216
email String?
@@ -286,10 +274,6 @@ describe('Zod plugin tests', () => {
286274
provider = 'prisma-client-js'
287275
}
288276
289-
plugin zod {
290-
provider = '@core/zod'
291-
}
292-
293277
model M {
294278
id Int @id @default(autoincrement())
295279
arr Int[]

packages/sdk/src/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@ export * from './constants';
33
export * from './types';
44
export * from './utils';
55
export * from './policy';
6+
export * from './validation';
67
export * from './prisma';

packages/sdk/src/policy.ts

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,6 @@
11
import type { DataModel, DataModelAttribute } from './ast';
22
import { getLiteral } from './utils';
3-
4-
export const VALIDATION_ATTRIBUTES = [
5-
'@length',
6-
'@regex',
7-
'@startsWith',
8-
'@endsWith',
9-
'@email',
10-
'@url',
11-
'@datetime',
12-
'@gt',
13-
'@gte',
14-
'@lt',
15-
'@lte',
16-
];
3+
import { hasValidationAttributes } from './validation';
174

185
export function analyzePolicies(dataModel: DataModel) {
196
const allows = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@allow');
@@ -23,9 +10,7 @@ export function analyzePolicies(dataModel: DataModel) {
2310
const read = toStaticPolicy('read', allows, denies);
2411
const update = toStaticPolicy('update', allows, denies);
2512
const del = toStaticPolicy('delete', allows, denies);
26-
const hasFieldValidation = dataModel.fields.some((field) =>
27-
field.attributes.some((attr) => VALIDATION_ATTRIBUTES.includes(attr.decl.$refText))
28-
);
13+
const hasFieldValidation = hasValidationAttributes(dataModel);
2914

3015
return {
3116
allows,

0 commit comments

Comments
 (0)