diff --git a/packages/language/syntaxes/zmodel.tmLanguage.json b/packages/language/syntaxes/zmodel.tmLanguage.json index 0afd6301d..2db106523 100644 --- a/packages/language/syntaxes/zmodel.tmLanguage.json +++ b/packages/language/syntaxes/zmodel.tmLanguage.json @@ -10,7 +10,7 @@ }, { "name": "keyword.control.zmodel", - "match": "\\b(Any|Asc|BigInt|Boolean|Bytes|ContextType|DateTime|Decimal|Desc|FieldReference|Float|Int|Json|Null|Object|String|TransitiveFieldReference|Unsupported|abstract|attribute|datasource|enum|extends|function|generator|import|in|model|plugin|sort)\\b" + "match": "\\b(Any|Asc|BigInt|Boolean|Bytes|ContextType|DateTime|Decimal|Desc|FieldReference|Float|Int|Json|Null|Object|String|TransitiveFieldReference|Unsupported|abstract|attribute|datasource|enum|extends|false|function|generator|import|in|model|plugin|sort|true)\\b" }, { "name": "string.quoted.double.zmodel", diff --git a/packages/schema/src/cli/plugin-runner.ts b/packages/schema/src/cli/plugin-runner.ts index e7aa32b3c..59559e0f7 100644 --- a/packages/schema/src/cli/plugin-runner.ts +++ b/packages/schema/src/cli/plugin-runner.ts @@ -3,7 +3,16 @@ import type { DMMF } from '@prisma/generator-helper'; import { getDMMF } from '@prisma/internals'; import { isPlugin, Plugin } from '@zenstackhq/language/ast'; -import { getLiteral, getLiteralArray, PluginError, PluginFunction, PluginOptions, resolvePath } from '@zenstackhq/sdk'; +import { + getDataModels, + getLiteral, + getLiteralArray, + hasValidationAttributes, + PluginError, + PluginFunction, + PluginOptions, + resolvePath, +} from '@zenstackhq/sdk'; import colors from 'colors'; import fs from 'fs'; import ora from 'ora'; @@ -90,13 +99,21 @@ export class PluginRunner { } // make sure prerequisites are included - const corePlugins = [ - '@core/prisma', - '@core/model-meta', - '@core/access-policy', - // core dependencies introduced by dependencies - ...plugins.flatMap((p) => p.dependencies).filter((dep) => dep.startsWith('@core/')), - ]; + const corePlugins = ['@core/prisma', '@core/model-meta', '@core/access-policy']; + + if (getDataModels(context.schema).some((model) => hasValidationAttributes(model))) { + // '@core/zod' plugin is auto-enabled if there're validation rules + corePlugins.push('@core/zod'); + } + + // core dependencies introduced by dependencies + plugins + .flatMap((p) => p.dependencies) + .forEach((dep) => { + if (dep.startsWith('@core/') && !corePlugins.includes(dep)) { + corePlugins.push(dep); + } + }); for (const corePlugin of corePlugins.reverse()) { const existingIdx = plugins.findIndex((p) => p.provider === corePlugin); diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index c022de5bc..6758eb7d3 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -24,7 +24,7 @@ import { getLiteral, getPrismaClientImportSpec, GUARD_FIELD_NAME, - hasAttribute, + hasValidationAttributes, PluginError, PluginOptions, resolved, @@ -38,7 +38,7 @@ import path from 'path'; import { FunctionDeclaration, SourceFile, VariableDeclarationKind } from 'ts-morph'; import { name } from '.'; import { isFromStdlib } from '../../language-server/utils'; -import { getIdFields, isAuthInvocation, VALIDATION_ATTRIBUTES } from '../../utils/ast-utils'; +import { getIdFields, isAuthInvocation } from '../../utils/ast-utils'; import { TypeScriptExpressionTransformer, TypeScriptExpressionTransformerError, @@ -113,7 +113,7 @@ export default class PolicyGenerator { for (const model of models) { writer.write(`${lowerCaseFirst(model.name)}:`); writer.inlineBlock(() => { - writer.write(`hasValidation: ${this.hasValidationAttributes(model)}`); + writer.write(`hasValidation: ${hasValidationAttributes(model)}`); }); writer.writeLine(','); } @@ -136,13 +136,6 @@ export default class PolicyGenerator { } } - private hasValidationAttributes(model: DataModel) { - return ( - hasAttribute(model, '@@validate') || - model.fields.some((field) => VALIDATION_ATTRIBUTES.some((attr) => hasAttribute(field, attr))) - ); - } - private getPolicyExpressions(model: DataModel, kind: PolicyKind, operation: PolicyOperationKind) { const attrs = model.attributes.filter((attr) => attr.decl.ref?.name === `@@${kind}`); diff --git a/packages/schema/src/res/stdlib.zmodel b/packages/schema/src/res/stdlib.zmodel index 8c78707e7..d3a5574d1 100644 --- a/packages/schema/src/res/stdlib.zmodel +++ b/packages/schema/src/res/stdlib.zmodel @@ -166,6 +166,11 @@ function isEmpty(field: Any[]): Boolean { */ attribute @@@targetField(targetField: AttributeTargetField[]) +/** + * Marks an attribute to be used for data validation. + */ +attribute @@@validation() + /** * Indicates the expression context a function can be used. */ @@ -377,67 +382,67 @@ attribute @omit() /** * Validates length of a string field. */ -attribute @length(_ min: Int?, _ max: Int?, _ message: String?) @@@targetField([StringField]) +attribute @length(_ min: Int?, _ max: Int?, _ message: String?) @@@targetField([StringField]) @@@validation /** * Validates a string field value starts with the given text. */ -attribute @startsWith(_ text: String, _ message: String?) @@@targetField([StringField]) +attribute @startsWith(_ text: String, _ message: String?) @@@targetField([StringField]) @@@validation /** * Validates a string field value ends with the given text. */ -attribute @endsWith(_ text: String, _ message: String?) @@@targetField([StringField]) +attribute @endsWith(_ text: String, _ message: String?) @@@targetField([StringField]) @@@validation /** * Validates a string field value contains the given text. */ -attribute @contains(_ text: String, _ message: String?) @@@targetField([StringField]) +attribute @contains(_ text: String, _ message: String?) @@@targetField([StringField]) @@@validation /** * Validates a string field value matches a regex. */ -attribute @regex(_ regex: String, _ message: String?) @@@targetField([StringField]) +attribute @regex(_ regex: String, _ message: String?) @@@targetField([StringField]) @@@validation /** * Validates a string field value is a valid email address. */ -attribute @email(_ message: String?) @@@targetField([StringField]) +attribute @email(_ message: String?) @@@targetField([StringField]) @@@validation /** * Validates a string field value is a valid ISO datetime. */ -attribute @datetime(_ message: String?) @@@targetField([StringField]) +attribute @datetime(_ message: String?) @@@targetField([StringField]) @@@validation /** * Validates a string field value is a valid url. */ -attribute @url(_ message: String?) @@@targetField([StringField]) +attribute @url(_ message: String?) @@@targetField([StringField]) @@@validation /** * Validates a number field is greater than the given value. */ -attribute @gt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) +attribute @gt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation /** * Validates a number field is greater than or equal to the given value. */ -attribute @gte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) +attribute @gte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation /** * Validates a number field is less than the given value. */ -attribute @lt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) +attribute @lt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation /** * Validates a number field is less than or equal to the given value. */ -attribute @lte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) +attribute @lte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation /** * Validates the entity with a complex condition. */ -attribute @@validate(_ value: Boolean, _ message: String?) +attribute @@validate(_ value: Boolean, _ message: String?) @@@validation /** * Validates length of a string field. diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index b8e1b5a0a..c6ea1545f 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -1,6 +1,5 @@ import { DataModel, - DataModelAttribute, DataModelField, Expression, isArrayExpr, @@ -14,8 +13,6 @@ import { ModelImport, ReferenceExpr, } from '@zenstackhq/language/ast'; -import { PolicyOperationKind } from '@zenstackhq/runtime'; -import { getLiteral } from '@zenstackhq/sdk'; import { AstNode, getDocument, LangiumDocuments, Mutable } from 'langium'; import { URI, Utils } from 'vscode-uri'; import { isFromStdlib } from '../language-server/utils'; @@ -26,31 +23,6 @@ export function extractDataModelsWithAllowRules(model: Model): DataModel[] { ) as DataModel[]; } -export function analyzePolicies(dataModel: DataModel) { - const allows = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@allow'); - const denies = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@deny'); - - const create = toStaticPolicy('create', allows, denies); - const read = toStaticPolicy('read', allows, denies); - const update = toStaticPolicy('update', allows, denies); - const del = toStaticPolicy('delete', allows, denies); - const hasFieldValidation = dataModel.$resolvedFields.some((field) => - field.attributes.some((attr) => VALIDATION_ATTRIBUTES.includes(attr.decl.$refText)) - ); - - return { - allows, - denies, - create, - read, - update, - delete: del, - allowAll: create === true && read === true && update === true && del === true, - denyAll: create === false && read === false && update === false && del === false, - hasFieldValidation, - }; -} - export function mergeBaseModel(model: Model) { model.declarations .filter((x) => x.$type === 'DataModel') @@ -82,61 +54,6 @@ function updateContainer(nodes: T[], container: AstNode): Mut }); } -function toStaticPolicy( - operation: PolicyOperationKind, - allows: DataModelAttribute[], - denies: DataModelAttribute[] -): boolean | undefined { - const filteredDenies = forOperation(operation, denies); - if (filteredDenies.some((rule) => getLiteral(rule.args[1].value) === true)) { - // any constant true deny rule - return false; - } - - const filteredAllows = forOperation(operation, allows); - if (filteredAllows.length === 0) { - // no allow rule - return false; - } - - if ( - filteredDenies.length === 0 && - filteredAllows.some((rule) => getLiteral(rule.args[1].value) === true) - ) { - // any constant true allow rule - return true; - } - return undefined; -} - -function forOperation(operation: PolicyOperationKind, rules: DataModelAttribute[]) { - return rules.filter((rule) => { - const ops = getLiteral(rule.args[0].value); - if (!ops) { - return false; - } - if (ops === 'all') { - return true; - } - const splitOps = ops.split(',').map((p) => p.trim()); - return splitOps.includes(operation); - }); -} - -export const VALIDATION_ATTRIBUTES = [ - '@length', - '@regex', - '@startsWith', - '@endsWith', - '@email', - '@url', - '@datetime', - '@gt', - '@gte', - '@lt', - '@lte', -]; - export function getIdFields(dataModel: DataModel) { const fieldLevelId = dataModel.$resolvedFields.find((f) => f.attributes.some((attr) => attr.decl.$refText === '@id') diff --git a/packages/schema/tests/plugins/zod.test.ts b/packages/schema/tests/plugins/zod.test.ts index ab20bf9e6..5fb335143 100644 --- a/packages/schema/tests/plugins/zod.test.ts +++ b/packages/schema/tests/plugins/zod.test.ts @@ -25,10 +25,6 @@ describe('Zod plugin tests', () => { provider = 'prisma-client-js' } - plugin zod { - provider = '@core/zod' - } - enum Role { USER ADMIN @@ -123,10 +119,6 @@ describe('Zod plugin tests', () => { provider = 'prisma-client-js' } - plugin zod { - provider = '@core/zod' - } - model M { id Int @id @default(autoincrement()) a String? @length(5, 10, 'must be between 5 and 10') @@ -219,10 +211,6 @@ describe('Zod plugin tests', () => { provider = 'prisma-client-js' } - plugin zod { - provider = '@core/zod' - } - model M { id Int @id @default(autoincrement()) email String? @@ -286,10 +274,6 @@ describe('Zod plugin tests', () => { provider = 'prisma-client-js' } - plugin zod { - provider = '@core/zod' - } - model M { id Int @id @default(autoincrement()) arr Int[] diff --git a/packages/sdk/src/index.ts b/packages/sdk/src/index.ts index 53da49054..05d630ba5 100644 --- a/packages/sdk/src/index.ts +++ b/packages/sdk/src/index.ts @@ -3,4 +3,5 @@ export * from './constants'; export * from './types'; export * from './utils'; export * from './policy'; +export * from './validation'; export * from './prisma'; diff --git a/packages/sdk/src/policy.ts b/packages/sdk/src/policy.ts index ef10fa633..ccd3e851f 100644 --- a/packages/sdk/src/policy.ts +++ b/packages/sdk/src/policy.ts @@ -1,19 +1,6 @@ import type { DataModel, DataModelAttribute } from './ast'; import { getLiteral } from './utils'; - -export const VALIDATION_ATTRIBUTES = [ - '@length', - '@regex', - '@startsWith', - '@endsWith', - '@email', - '@url', - '@datetime', - '@gt', - '@gte', - '@lt', - '@lte', -]; +import { hasValidationAttributes } from './validation'; export function analyzePolicies(dataModel: DataModel) { const allows = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@allow'); @@ -23,9 +10,7 @@ export function analyzePolicies(dataModel: DataModel) { const read = toStaticPolicy('read', allows, denies); const update = toStaticPolicy('update', allows, denies); const del = toStaticPolicy('delete', allows, denies); - const hasFieldValidation = dataModel.fields.some((field) => - field.attributes.some((attr) => VALIDATION_ATTRIBUTES.includes(attr.decl.$refText)) - ); + const hasFieldValidation = hasValidationAttributes(dataModel); return { allows, diff --git a/packages/sdk/src/validation.ts b/packages/sdk/src/validation.ts new file mode 100644 index 000000000..e7edc21fc --- /dev/null +++ b/packages/sdk/src/validation.ts @@ -0,0 +1,21 @@ +import type { DataModel, DataModelAttribute, DataModelFieldAttribute } from './ast'; + +function isValidationAttribute(attr: DataModelAttribute | DataModelFieldAttribute) { + return attr.decl.ref?.attributes.some((attr) => attr.decl.$refText === '@@@validation'); +} + +/** + * Returns if the given model contains any data validation rules (both at the model + * level and at the field level). + */ +export function hasValidationAttributes(model: DataModel) { + if (model.attributes.some((attr) => isValidationAttribute(attr))) { + return true; + } + + if (model.fields.some((field) => field.attributes.some((attr) => isValidationAttribute(attr)))) { + return true; + } + + return false; +} diff --git a/tests/integration/tests/schema/cal-com.zmodel b/tests/integration/tests/schema/cal-com.zmodel index 672c01b07..bc693d8d0 100644 --- a/tests/integration/tests/schema/cal-com.zmodel +++ b/tests/integration/tests/schema/cal-com.zmodel @@ -11,10 +11,6 @@ generator client { previewFeatures = [] } -plugin zod { - provider = '@core/zod' -} - enum SchedulingType { ROUND_ROBIN @map("roundRobin") COLLECTIVE @map("collective") diff --git a/tests/integration/tests/schema/todo.zmodel b/tests/integration/tests/schema/todo.zmodel index f107bf0fa..2bc57ac24 100644 --- a/tests/integration/tests/schema/todo.zmodel +++ b/tests/integration/tests/schema/todo.zmodel @@ -12,10 +12,6 @@ generator js { previewFeatures = ['clientExtensions'] } -plugin zod { - provider = '@core/zod' -} - /* * Model for a space in which users can collaborate on Lists and Todos */