Skip to content

fix: automatically enable "@core/zod" plugin when there're validation rules #535

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/language/syntaxes/zmodel.tmLanguage.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
},
{
"name": "keyword.control.zmodel",
"match": "\\b(Any|Asc|BigInt|Boolean|Bytes|ContextType|DateTime|Decimal|Desc|FieldReference|Float|Int|Json|Null|Object|String|TransitiveFieldReference|Unsupported|abstract|attribute|datasource|enum|extends|function|generator|import|in|model|plugin|sort)\\b"
"match": "\\b(Any|Asc|BigInt|Boolean|Bytes|ContextType|DateTime|Decimal|Desc|FieldReference|Float|Int|Json|Null|Object|String|TransitiveFieldReference|Unsupported|abstract|attribute|datasource|enum|extends|false|function|generator|import|in|model|plugin|sort|true)\\b"
},
{
"name": "string.quoted.double.zmodel",
Expand Down
33 changes: 25 additions & 8 deletions packages/schema/src/cli/plugin-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@
import type { DMMF } from '@prisma/generator-helper';
import { getDMMF } from '@prisma/internals';
import { isPlugin, Plugin } from '@zenstackhq/language/ast';
import { getLiteral, getLiteralArray, PluginError, PluginFunction, PluginOptions, resolvePath } from '@zenstackhq/sdk';
import {
getDataModels,
getLiteral,
getLiteralArray,
hasValidationAttributes,
PluginError,
PluginFunction,
PluginOptions,
resolvePath,
} from '@zenstackhq/sdk';
import colors from 'colors';
import fs from 'fs';
import ora from 'ora';
Expand Down Expand Up @@ -90,13 +99,21 @@ export class PluginRunner {
}

// make sure prerequisites are included
const corePlugins = [
'@core/prisma',
'@core/model-meta',
'@core/access-policy',
// core dependencies introduced by dependencies
...plugins.flatMap((p) => p.dependencies).filter((dep) => dep.startsWith('@core/')),
];
const corePlugins = ['@core/prisma', '@core/model-meta', '@core/access-policy'];

if (getDataModels(context.schema).some((model) => hasValidationAttributes(model))) {
// '@core/zod' plugin is auto-enabled if there're validation rules
corePlugins.push('@core/zod');
}

// core dependencies introduced by dependencies
plugins
.flatMap((p) => p.dependencies)
.forEach((dep) => {
if (dep.startsWith('@core/') && !corePlugins.includes(dep)) {
corePlugins.push(dep);
}
});

for (const corePlugin of corePlugins.reverse()) {
const existingIdx = plugins.findIndex((p) => p.provider === corePlugin);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import {
getLiteral,
getPrismaClientImportSpec,
GUARD_FIELD_NAME,
hasAttribute,
hasValidationAttributes,
PluginError,
PluginOptions,
resolved,
Expand All @@ -38,7 +38,7 @@ import path from 'path';
import { FunctionDeclaration, SourceFile, VariableDeclarationKind } from 'ts-morph';
import { name } from '.';
import { isFromStdlib } from '../../language-server/utils';
import { getIdFields, isAuthInvocation, VALIDATION_ATTRIBUTES } from '../../utils/ast-utils';
import { getIdFields, isAuthInvocation } from '../../utils/ast-utils';
import {
TypeScriptExpressionTransformer,
TypeScriptExpressionTransformerError,
Expand Down Expand Up @@ -113,7 +113,7 @@ export default class PolicyGenerator {
for (const model of models) {
writer.write(`${lowerCaseFirst(model.name)}:`);
writer.inlineBlock(() => {
writer.write(`hasValidation: ${this.hasValidationAttributes(model)}`);
writer.write(`hasValidation: ${hasValidationAttributes(model)}`);
});
writer.writeLine(',');
}
Expand All @@ -136,13 +136,6 @@ export default class PolicyGenerator {
}
}

private hasValidationAttributes(model: DataModel) {
return (
hasAttribute(model, '@@validate') ||
model.fields.some((field) => VALIDATION_ATTRIBUTES.some((attr) => hasAttribute(field, attr)))
);
}

private getPolicyExpressions(model: DataModel, kind: PolicyKind, operation: PolicyOperationKind) {
const attrs = model.attributes.filter((attr) => attr.decl.ref?.name === `@@${kind}`);

Expand Down
31 changes: 18 additions & 13 deletions packages/schema/src/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ function isEmpty(field: Any[]): Boolean {
*/
attribute @@@targetField(targetField: AttributeTargetField[])

/**
* Marks an attribute to be used for data validation.
*/
attribute @@@validation()

/**
* Indicates the expression context a function can be used.
*/
Expand Down Expand Up @@ -377,67 +382,67 @@ attribute @omit()
/**
* Validates length of a string field.
*/
attribute @length(_ min: Int?, _ max: Int?, _ message: String?) @@@targetField([StringField])
attribute @length(_ min: Int?, _ max: Int?, _ message: String?) @@@targetField([StringField]) @@@validation

/**
* Validates a string field value starts with the given text.
*/
attribute @startsWith(_ text: String, _ message: String?) @@@targetField([StringField])
attribute @startsWith(_ text: String, _ message: String?) @@@targetField([StringField]) @@@validation

/**
* Validates a string field value ends with the given text.
*/
attribute @endsWith(_ text: String, _ message: String?) @@@targetField([StringField])
attribute @endsWith(_ text: String, _ message: String?) @@@targetField([StringField]) @@@validation

/**
* Validates a string field value contains the given text.
*/
attribute @contains(_ text: String, _ message: String?) @@@targetField([StringField])
attribute @contains(_ text: String, _ message: String?) @@@targetField([StringField]) @@@validation

/**
* Validates a string field value matches a regex.
*/
attribute @regex(_ regex: String, _ message: String?) @@@targetField([StringField])
attribute @regex(_ regex: String, _ message: String?) @@@targetField([StringField]) @@@validation

/**
* Validates a string field value is a valid email address.
*/
attribute @email(_ message: String?) @@@targetField([StringField])
attribute @email(_ message: String?) @@@targetField([StringField]) @@@validation

/**
* Validates a string field value is a valid ISO datetime.
*/
attribute @datetime(_ message: String?) @@@targetField([StringField])
attribute @datetime(_ message: String?) @@@targetField([StringField]) @@@validation

/**
* Validates a string field value is a valid url.
*/
attribute @url(_ message: String?) @@@targetField([StringField])
attribute @url(_ message: String?) @@@targetField([StringField]) @@@validation

/**
* Validates a number field is greater than the given value.
*/
attribute @gt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField])
attribute @gt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation

/**
* Validates a number field is greater than or equal to the given value.
*/
attribute @gte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField])
attribute @gte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation

/**
* Validates a number field is less than the given value.
*/
attribute @lt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField])
attribute @lt(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation

/**
* Validates a number field is less than or equal to the given value.
*/
attribute @lte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField])
attribute @lte(_ value: Int, _ message: String?) @@@targetField([IntField, FloatField, DecimalField]) @@@validation

/**
* Validates the entity with a complex condition.
*/
attribute @@validate(_ value: Boolean, _ message: String?)
attribute @@validate(_ value: Boolean, _ message: String?) @@@validation

/**
* Validates length of a string field.
Expand Down
83 changes: 0 additions & 83 deletions packages/schema/src/utils/ast-utils.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import {
DataModel,
DataModelAttribute,
DataModelField,
Expression,
isArrayExpr,
Expand All @@ -14,8 +13,6 @@ import {
ModelImport,
ReferenceExpr,
} from '@zenstackhq/language/ast';
import { PolicyOperationKind } from '@zenstackhq/runtime';
import { getLiteral } from '@zenstackhq/sdk';
import { AstNode, getDocument, LangiumDocuments, Mutable } from 'langium';
import { URI, Utils } from 'vscode-uri';
import { isFromStdlib } from '../language-server/utils';
Expand All @@ -26,31 +23,6 @@ export function extractDataModelsWithAllowRules(model: Model): DataModel[] {
) as DataModel[];
}

export function analyzePolicies(dataModel: DataModel) {
const allows = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@allow');
const denies = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@deny');

const create = toStaticPolicy('create', allows, denies);
const read = toStaticPolicy('read', allows, denies);
const update = toStaticPolicy('update', allows, denies);
const del = toStaticPolicy('delete', allows, denies);
const hasFieldValidation = dataModel.$resolvedFields.some((field) =>
field.attributes.some((attr) => VALIDATION_ATTRIBUTES.includes(attr.decl.$refText))
);

return {
allows,
denies,
create,
read,
update,
delete: del,
allowAll: create === true && read === true && update === true && del === true,
denyAll: create === false && read === false && update === false && del === false,
hasFieldValidation,
};
}

export function mergeBaseModel(model: Model) {
model.declarations
.filter((x) => x.$type === 'DataModel')
Expand Down Expand Up @@ -82,61 +54,6 @@ function updateContainer<T extends AstNode>(nodes: T[], container: AstNode): Mut
});
}

function toStaticPolicy(
operation: PolicyOperationKind,
allows: DataModelAttribute[],
denies: DataModelAttribute[]
): boolean | undefined {
const filteredDenies = forOperation(operation, denies);
if (filteredDenies.some((rule) => getLiteral<boolean>(rule.args[1].value) === true)) {
// any constant true deny rule
return false;
}

const filteredAllows = forOperation(operation, allows);
if (filteredAllows.length === 0) {
// no allow rule
return false;
}

if (
filteredDenies.length === 0 &&
filteredAllows.some((rule) => getLiteral<boolean>(rule.args[1].value) === true)
) {
// any constant true allow rule
return true;
}
return undefined;
}

function forOperation(operation: PolicyOperationKind, rules: DataModelAttribute[]) {
return rules.filter((rule) => {
const ops = getLiteral<string>(rule.args[0].value);
if (!ops) {
return false;
}
if (ops === 'all') {
return true;
}
const splitOps = ops.split(',').map((p) => p.trim());
return splitOps.includes(operation);
});
}

export const VALIDATION_ATTRIBUTES = [
'@length',
'@regex',
'@startsWith',
'@endsWith',
'@email',
'@url',
'@datetime',
'@gt',
'@gte',
'@lt',
'@lte',
];

export function getIdFields(dataModel: DataModel) {
const fieldLevelId = dataModel.$resolvedFields.find((f) =>
f.attributes.some((attr) => attr.decl.$refText === '@id')
Expand Down
16 changes: 0 additions & 16 deletions packages/schema/tests/plugins/zod.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@ describe('Zod plugin tests', () => {
provider = 'prisma-client-js'
}

plugin zod {
provider = '@core/zod'
}

enum Role {
USER
ADMIN
Expand Down Expand Up @@ -123,10 +119,6 @@ describe('Zod plugin tests', () => {
provider = 'prisma-client-js'
}

plugin zod {
provider = '@core/zod'
}

model M {
id Int @id @default(autoincrement())
a String? @length(5, 10, 'must be between 5 and 10')
Expand Down Expand Up @@ -219,10 +211,6 @@ describe('Zod plugin tests', () => {
provider = 'prisma-client-js'
}

plugin zod {
provider = '@core/zod'
}

model M {
id Int @id @default(autoincrement())
email String?
Expand Down Expand Up @@ -286,10 +274,6 @@ describe('Zod plugin tests', () => {
provider = 'prisma-client-js'
}

plugin zod {
provider = '@core/zod'
}

model M {
id Int @id @default(autoincrement())
arr Int[]
Expand Down
1 change: 1 addition & 0 deletions packages/sdk/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ export * from './constants';
export * from './types';
export * from './utils';
export * from './policy';
export * from './validation';
export * from './prisma';
19 changes: 2 additions & 17 deletions packages/sdk/src/policy.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,6 @@
import type { DataModel, DataModelAttribute } from './ast';
import { getLiteral } from './utils';

export const VALIDATION_ATTRIBUTES = [
'@length',
'@regex',
'@startsWith',
'@endsWith',
'@email',
'@url',
'@datetime',
'@gt',
'@gte',
'@lt',
'@lte',
];
import { hasValidationAttributes } from './validation';

export function analyzePolicies(dataModel: DataModel) {
const allows = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@allow');
Expand All @@ -23,9 +10,7 @@ export function analyzePolicies(dataModel: DataModel) {
const read = toStaticPolicy('read', allows, denies);
const update = toStaticPolicy('update', allows, denies);
const del = toStaticPolicy('delete', allows, denies);
const hasFieldValidation = dataModel.fields.some((field) =>
field.attributes.some((attr) => VALIDATION_ATTRIBUTES.includes(attr.decl.$refText))
);
const hasFieldValidation = hasValidationAttributes(dataModel);

return {
allows,
Expand Down
Loading