diff --git a/packages/language/src/ast.ts b/packages/language/src/ast.ts index b9888eb9d..c8637115a 100644 --- a/packages/language/src/ast.ts +++ b/packages/language/src/ast.ts @@ -43,6 +43,17 @@ declare module './generated/ast' { */ $resolvedParam?: AttributeParam; } + + interface DataModel { + /** + * Resolved fields, include inherited fields + */ + $resolvedFields: Array; + } + + interface DataModelField { + $isInherited?: boolean; + } } declare module 'langium' { diff --git a/packages/language/src/generated/ast.ts b/packages/language/src/generated/ast.ts index aed415a7b..9783b2186 100644 --- a/packages/language/src/generated/ast.ts +++ b/packages/language/src/generated/ast.ts @@ -168,7 +168,9 @@ export interface DataModel extends AstNode { attributes: Array comments: Array fields: Array + isAbstract: boolean name: RegularID + superTypes: Array> } export const DataModel = 'DataModel'; @@ -645,6 +647,9 @@ export class ZModelAstReflection extends AbstractAstReflection { case 'FunctionParamType:reference': { return TypeDeclaration; } + case 'DataModel:superTypes': { + return DataModel; + } case 'InvocationExpr:function': { return FunctionDecl; } @@ -710,7 +715,9 @@ export class ZModelAstReflection extends AbstractAstReflection { mandatory: [ { name: 'attributes', type: 'array' }, { name: 'comments', type: 'array' }, - { name: 'fields', type: 'array' } + { name: 'fields', type: 'array' }, + { name: 'isAbstract', type: 'boolean' }, + { name: 'superTypes', type: 'array' } ] }; } diff --git a/packages/language/src/generated/grammar.ts b/packages/language/src/generated/grammar.ts index b7f2398f8..646a97650 100644 --- a/packages/language/src/generated/grammar.ts +++ b/packages/language/src/generated/grammar.ts @@ -1680,6 +1680,16 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel }, "cardinality": "*" }, + { + "$type": "Assignment", + "feature": "isAbstract", + "operator": "?=", + "terminal": { + "$type": "Keyword", + "value": "abstract" + }, + "cardinality": "?" + }, { "$type": "Keyword", "value": "model" @@ -1696,6 +1706,50 @@ export const ZModelGrammar = (): Grammar => loadedZModelGrammar ?? (loadedZModel "arguments": [] } }, + { + "$type": "Group", + "elements": [ + { + "$type": "Keyword", + "value": "extends" + }, + { + "$type": "Assignment", + "feature": "superTypes", + "operator": "+=", + "terminal": { + "$type": "CrossReference", + "type": { + "$ref": "#/rules@30" + }, + "deprecatedSyntax": false + } + }, + { + "$type": "Group", + "elements": [ + { + "$type": "Keyword", + "value": "," + }, + { + "$type": "Assignment", + "feature": "superTypes", + "operator": "+=", + "terminal": { + "$type": "CrossReference", + "type": { + "$ref": "#/rules@30" + }, + "deprecatedSyntax": false + } + } + ], + "cardinality": "*" + } + ], + "cardinality": "?" + }, { "$type": "Keyword", "value": "{" diff --git a/packages/language/src/zmodel.langium b/packages/language/src/zmodel.langium index 5c8ef8bd8..629c2e10e 100644 --- a/packages/language/src/zmodel.langium +++ b/packages/language/src/zmodel.langium @@ -156,8 +156,9 @@ Argument: // model DataModel: (comments+=TRIPLE_SLASH_COMMENT)* - 'model' name=RegularID '{' ( - fields+=DataModelField + (isAbstract?='abstract')? 'model' name=RegularID + ('extends' superTypes+=[DataModel] (',' superTypes+=[DataModel])*)? '{' ( + fields+=DataModelField | attributes+=DataModelAttribute )+ '}'; diff --git a/packages/language/syntaxes/zmodel.tmLanguage.json b/packages/language/syntaxes/zmodel.tmLanguage.json index 62ccd1c2b..0afd6301d 100644 --- a/packages/language/syntaxes/zmodel.tmLanguage.json +++ b/packages/language/syntaxes/zmodel.tmLanguage.json @@ -10,7 +10,7 @@ }, { "name": "keyword.control.zmodel", - "match": "\\b(Any|Asc|BigInt|Boolean|Bytes|ContextType|DateTime|Decimal|Desc|FieldReference|Float|Int|Json|Null|Object|String|TransitiveFieldReference|Unsupported|attribute|datasource|enum|function|generator|import|in|model|plugin|sort)\\b" + "match": "\\b(Any|Asc|BigInt|Boolean|Bytes|ContextType|DateTime|Decimal|Desc|FieldReference|Float|Int|Json|Null|Object|String|TransitiveFieldReference|Unsupported|abstract|attribute|datasource|enum|extends|function|generator|import|in|model|plugin|sort)\\b" }, { "name": "string.quoted.double.zmodel", diff --git a/packages/schema/src/cli/cli-util.ts b/packages/schema/src/cli/cli-util.ts index ccc658e43..30ad3da62 100644 --- a/packages/schema/src/cli/cli-util.ts +++ b/packages/schema/src/cli/cli-util.ts @@ -12,7 +12,7 @@ import { URI } from 'vscode-uri'; import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from '../language-server/constants'; import { createZModelServices, ZModelServices } from '../language-server/zmodel-module'; import { Context } from '../types'; -import { resolveImport, resolveTransitiveImports } from '../utils/ast-utils'; +import { mergeBaseModel, resolveImport, resolveTransitiveImports } from '../utils/ast-utils'; import { ensurePackage, installPackage, PackageManagers } from '../utils/pkg-utils'; import { getVersion } from '../utils/version-utils'; import { CliError } from './cli-error'; @@ -125,7 +125,11 @@ export async function loadDocument(fileName: string): Promise { } ); - const validationErrors = (document.diagnostics ?? []).filter((e) => e.severity === 1); + const validationErrors = langiumDocuments.all + .flatMap((d) => d.diagnostics ?? []) + .filter((e) => e.severity === 1) + .toArray(); + if (validationErrors.length > 0) { console.error(colors.red('Validation errors:')); for (const validationError of validationErrors) { @@ -145,6 +149,9 @@ export async function loadDocument(fileName: string): Promise { mergeImportsDeclarations(langiumDocuments, model); validationAfterMerge(model); + + mergeBaseModel(model); + return model; } @@ -179,7 +186,9 @@ export function eagerLoadAllImports( } } - return Array.from(uris).map((e) => URI.parse(e)); + return Array.from(uris) + .filter((x) => uriString != x) + .map((e) => URI.parse(e)); } export function mergeImportsDeclarations(documents: LangiumDocuments, model: Model) { diff --git a/packages/schema/src/language-server/validator/datamodel-validator.ts b/packages/schema/src/language-server/validator/datamodel-validator.ts index e0934cc45..ec305330f 100644 --- a/packages/schema/src/language-server/validator/datamodel-validator.ts +++ b/packages/schema/src/language-server/validator/datamodel-validator.ts @@ -7,7 +7,7 @@ import { ReferenceExpr, } from '@zenstackhq/language/ast'; import { analyzePolicies, getLiteral } from '@zenstackhq/sdk'; -import { ValidationAcceptor } from 'langium'; +import { AstNode, DiagnosticInfo, getDocument, ValidationAcceptor } from 'langium'; import { IssueCodes, SCALAR_TYPES } from '../constants'; import { AstValidator } from '../types'; import { getIdFields, getUniqueFields } from '../utils'; @@ -18,13 +18,14 @@ import { validateAttributeApplication, validateDuplicatedDeclarations } from './ */ export default class DataModelValidator implements AstValidator { validate(dm: DataModel, accept: ValidationAcceptor): void { - validateDuplicatedDeclarations(dm.fields, accept); + this.validateBaseAbstractModel(dm, accept); + validateDuplicatedDeclarations(dm.$resolvedFields, accept); this.validateAttributes(dm, accept); this.validateFields(dm, accept); } private validateFields(dm: DataModel, accept: ValidationAcceptor) { - const idFields = dm.fields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id')); + const idFields = dm.$resolvedFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id')); const modelLevelIds = getIdFields(dm); if (idFields.length === 0 && modelLevelIds.length === 0) { @@ -57,6 +58,14 @@ export default class DataModelValidator implements AstValidator { } dm.fields.forEach((field) => this.validateField(field, accept)); + + if (!dm.isAbstract) { + dm.$resolvedFields + .filter((x) => isDataModel(x.type.reference?.ref)) + .forEach((y) => { + this.validateRelationField(y, accept); + }); + } } private validateField(field: DataModelField, accept: ValidationAcceptor): void { @@ -69,10 +78,6 @@ export default class DataModelValidator implements AstValidator { } field.attributes.forEach((attr) => validateAttributeApplication(attr, accept)); - - if (isDataModel(field.type.reference?.ref)) { - this.validateRelationField(field, accept); - } } private validateAttributes(dm: DataModel, accept: ValidationAcceptor) { @@ -175,8 +180,9 @@ export default class DataModelValidator implements AstValidator { if (relationName) { // field's relation points to another type, and that type's opposite relation field // points back - const oppositeModelFields = field.type.reference?.ref?.fields as DataModelField[]; - if (oppositeModelFields) { + const oppositeModel = field.type.reference?.ref as DataModel; + if (oppositeModel) { + const oppositeModelFields = oppositeModel.$resolvedFields as DataModelField[]; for (const oppositeField of oppositeModelFields) { // find the opposite relation with the matching name const relAttr = oppositeField.attributes.find((a) => a.decl.ref?.name === '@relation'); @@ -204,34 +210,68 @@ export default class DataModelValidator implements AstValidator { // eslint-disable-next-line @typescript-eslint/no-non-null-assertion const oppositeModel = field.type.reference!.ref! as DataModel; - let oppositeFields = oppositeModel.fields.filter((f) => f.type.reference?.ref === field.$container); + // Use name because the current document might be updated + let oppositeFields = oppositeModel.$resolvedFields.filter( + (f) => f.type.reference?.ref?.name === field.$container.name + ); oppositeFields = oppositeFields.filter((f) => { const fieldRel = this.parseRelation(f); return fieldRel.valid && fieldRel.name === thisRelation.name; }); if (oppositeFields.length === 0) { + const node = field.$isInherited ? field.$container : field; + const info: DiagnosticInfo = { node, code: IssueCodes.MissingOppositeRelation }; + + let relationFieldDocUri: string; + let relationDataModelName: string; + + if (field.$isInherited) { + info.property = 'name'; + const container = field.$container as DataModel; + const abstractContainer = container.superTypes.find((x) => + x.ref?.fields.find((f) => f.name === field.name) + )?.ref as DataModel; + + relationFieldDocUri = getDocument(abstractContainer).textDocument.uri; + relationDataModelName = abstractContainer.name; + } else { + relationFieldDocUri = getDocument(field).textDocument.uri; + relationDataModelName = field.$container.name; + } + + const data: MissingOppositeRelationData = { + relationFieldName: field.name, + relationDataModelName, + relationFieldDocUri, + dataModelName: field.$container.name, + }; + + info.data = data; + accept( 'error', `The relation field "${field.name}" on model "${field.$container.name}" is missing an opposite relation field on model "${oppositeModel.name}"`, - { node: field, code: IssueCodes.MissingOppositeRelation } + info ); return; } else if (oppositeFields.length > 1) { - oppositeFields.forEach((f) => { - if (this.isSelfRelation(f)) { - // self relations are partial - // https://www.prisma.io/docs/concepts/components/prisma-schema/relations/self-relations - } else { - accept( - 'error', - `Fields ${oppositeFields.map((f) => '"' + f.name + '"').join(', ')} on model "${ - oppositeModel.name - }" refer to the same relation to model "${field.$container.name}"`, - { node: f } - ); - } - }); + oppositeFields + .filter((x) => !x.$isInherited) + .forEach((f) => { + if (this.isSelfRelation(f)) { + // self relations are partial + // https://www.prisma.io/docs/concepts/components/prisma-schema/relations/self-relations + } else { + accept( + 'error', + `Fields ${oppositeFields.map((f) => '"' + f.name + '"').join(', ')} on model "${ + oppositeModel.name + }" refer to the same relation to model "${field.$container.name}"`, + { node: f } + ); + } + }); return; } @@ -317,4 +357,26 @@ export default class DataModelValidator implements AstValidator { }); } } + + private validateBaseAbstractModel(model: DataModel, accept: ValidationAcceptor) { + model.superTypes.forEach((superType, index) => { + if (!superType.ref?.isAbstract) + accept('error', `Model ${superType.$refText} cannot be extended because it's not abstract`, { + node: model, + property: 'superTypes', + index, + }); + }); + } +} + +export interface MissingOppositeRelationData { + relationDataModelName: string; + relationFieldName: string; + // it might be the abstract model in the imported document + relationFieldDocUri: string; + + // the name of DataModel that the relation field belongs to. + // the document is the same with the error node. + dataModelName: string; } diff --git a/packages/schema/src/language-server/validator/utils.ts b/packages/schema/src/language-server/validator/utils.ts index 4ee378b81..7b2f88b62 100644 --- a/packages/schema/src/language-server/validator/utils.ts +++ b/packages/schema/src/language-server/validator/utils.ts @@ -37,8 +37,16 @@ export function validateDuplicatedDeclarations( for (const [name, decls] of Object.entries(groupByName)) { if (decls.length > 1) { + let errorField = decls[1]; + if (decls[0].$type === 'DataModelField') { + const nonInheritedFields = decls.filter((x) => !(x as DataModelField).$isInherited); + if (nonInheritedFields.length > 0) { + errorField = nonInheritedFields.slice(-1)[0]; + } + } + accept('error', `Duplicated declaration name "${name}"`, { - node: decls[1], + node: errorField, }); } } diff --git a/packages/schema/src/language-server/zmodel-code-action.ts b/packages/schema/src/language-server/zmodel-code-action.ts index 5638a7c2c..e9e7862ec 100644 --- a/packages/schema/src/language-server/zmodel-code-action.ts +++ b/packages/schema/src/language-server/zmodel-code-action.ts @@ -1,12 +1,11 @@ -import { DataModel, DataModelField, isDataModel } from '@zenstackhq/language/ast'; +import { DataModel, DataModelField, Model, isDataModel } from '@zenstackhq/language/ast'; import { AstReflection, CodeActionProvider, - findDeclarationNodeAtOffset, - getContainerOfType, getDocument, IndexManager, LangiumDocument, + LangiumDocuments, LangiumServices, MaybePromise, } from 'langium'; @@ -14,16 +13,19 @@ import { import { CodeAction, CodeActionKind, CodeActionParams, Command, Diagnostic } from 'vscode-languageserver'; import { IssueCodes } from './constants'; import { ZModelFormatter } from './zmodel-formatter'; +import { MissingOppositeRelationData } from './validator/datamodel-validator'; export class ZModelCodeActionProvider implements CodeActionProvider { protected readonly reflection: AstReflection; protected readonly indexManager: IndexManager; protected readonly formatter: ZModelFormatter; + protected readonly documents: LangiumDocuments; constructor(services: LangiumServices) { this.reflection = services.shared.AstReflection; this.indexManager = services.shared.workspace.IndexManager; this.formatter = services.lsp.Formatter as ZModelFormatter; + this.documents = services.shared.workspace.LangiumDocuments; } getCodeActions( @@ -52,20 +54,34 @@ export class ZModelCodeActionProvider implements CodeActionProvider { } private fixMissingOppositeRelation(diagnostic: Diagnostic, document: LangiumDocument): CodeAction | undefined { - const offset = document.textDocument.offsetAt(diagnostic.range.start); - const rootCst = document.parseResult.value.$cstNode; + const data = diagnostic.data as MissingOppositeRelationData; + + const rootCst = + data.relationFieldDocUri == document.textDocument.uri + ? document.parseResult.value + : this.documents.all.find((doc) => doc.textDocument.uri === data.relationFieldDocUri)?.parseResult + .value; if (rootCst) { - const cstNode = findDeclarationNodeAtOffset(rootCst, offset); + const fieldModel = rootCst as Model; + const fieldAstNode = ( + fieldModel.declarations.find( + (x) => isDataModel(x) && x.name === data.relationDataModelName + ) as DataModel + )?.fields.find((x) => x.name === data.relationFieldName) as DataModelField; - const astNode = cstNode?.element as DataModelField; + if (!fieldAstNode) return undefined; // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const oppositeModel = astNode.type.reference!.ref! as DataModel; + const oppositeModel = fieldAstNode.type.reference!.ref! as DataModel; const lastField = oppositeModel.fields[oppositeModel.fields.length - 1]; - const container = getContainerOfType(cstNode?.element, isDataModel) as DataModel; + const currentModel = document.parseResult.value as Model; + + const container = currentModel.declarations.find( + (decl) => decl.name === data.dataModelName && isDataModel(decl) + ) as DataModel; if (container && container.$cstNode) { // indent @@ -77,9 +93,9 @@ export class ZModelCodeActionProvider implements CodeActionProvider { indent = indent.repeat(this.formatter.getIndent()); let newText = ''; - if (astNode.type.array) { + if (fieldAstNode.type.array) { //post Post[] - const idField = container.fields.find((f) => + const idField = container.$resolvedFields.find((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id') ) as DataModelField; @@ -97,7 +113,7 @@ export class ZModelCodeActionProvider implements CodeActionProvider { const idFieldName = idField.name; const referenceIdFieldName = fieldName + this.upperCaseFirstLetter(idFieldName); - if (!oppositeModel.fields.find((f) => f.name === referenceIdFieldName)) { + if (!oppositeModel.$resolvedFields.find((f) => f.name === referenceIdFieldName)) { referenceField = '\n' + indent + `${referenceIdFieldName} ${idField.type.type}`; } diff --git a/packages/schema/src/language-server/zmodel-linker.ts b/packages/schema/src/language-server/zmodel-linker.ts index b7e828f65..d6c4eda03 100644 --- a/packages/schema/src/language-server/zmodel-linker.ts +++ b/packages/schema/src/language-server/zmodel-linker.ts @@ -67,6 +67,10 @@ export class ZModelLinker extends DefaultLinker { //#region Reference linking async link(document: LangiumDocument, cancelToken = CancellationToken.None): Promise { + if (document.parseResult.lexerErrors?.length > 0 || document.parseResult.parserErrors?.length > 0) { + return; + } + for (const node of streamContents(document.parseResult.value)) { await interruptAndCheck(cancelToken); this.resolve(node, document); @@ -156,6 +160,10 @@ export class ZModelLinker extends DefaultLinker { this.resolveAttributeArg(node as AttributeArg, document, extraScopes); break; + case DataModel: + this.resolveDataModel(node as DataModel, document, extraScopes); + break; + default: this.resolveDefault(node, document, extraScopes); break; @@ -299,7 +307,7 @@ export class ZModelLinker extends DefaultLinker { if (operandResolved && !operandResolved.array && isDataModel(operandResolved.decl)) { const modelDecl = operandResolved.decl as DataModel; - const provider = (name: string) => modelDecl.fields.find((f) => f.name === name); + const provider = (name: string) => modelDecl.$resolvedFields.find((f) => f.name === name); extraScopes = [provider, ...extraScopes]; } @@ -315,7 +323,7 @@ export class ZModelLinker extends DefaultLinker { const resolvedType = node.left.$resolvedType; if (resolvedType && isDataModel(resolvedType.decl) && resolvedType.array) { const dataModelDecl = resolvedType.decl; - const provider = (name: string) => dataModelDecl.fields.find((f) => f.name === name); + const provider = (name: string) => dataModelDecl.$resolvedFields.find((f) => f.name === name); extraScopes = [provider, ...extraScopes]; this.resolve(node.right, document, extraScopes); this.resolveToBuiltinTypeOrDecl(node, 'Boolean'); @@ -377,7 +385,7 @@ export class ZModelLinker extends DefaultLinker { const transtiveDataModel = attrAppliedOn.type.reference?.ref as DataModel; if (transtiveDataModel) { // resolve references in the context of the transitive data model - const scopeProvider = (name: string) => transtiveDataModel.fields.find((f) => f.name === name); + const scopeProvider = (name: string) => transtiveDataModel.$resolvedFields.find((f) => f.name === name); if (isArrayExpr(node.value)) { node.value.items.forEach((item) => { if (isReferenceExpr(item)) { @@ -432,6 +440,17 @@ export class ZModelLinker extends DefaultLinker { } } + private resolveDataModel(node: DataModel, document: LangiumDocument, extraScopes: ScopeProvider[]) { + if (node.superTypes.length > 0) { + const providers = node.superTypes.map( + (superType) => (name: string) => superType.ref?.fields.find((f) => f.name === name) + ); + extraScopes = [...providers, ...extraScopes]; + } + + return this.resolveDefault(node, document, extraScopes); + } + private resolveDefault(node: AstNode, document: LangiumDocument, extraScopes: ScopeProvider[]) { for (const [property, value] of Object.entries(node)) { if (!property.startsWith('$')) { diff --git a/packages/schema/src/language-server/zmodel-scope.ts b/packages/schema/src/language-server/zmodel-scope.ts index 3b8316fdb..95fecacc7 100644 --- a/packages/schema/src/language-server/zmodel-scope.ts +++ b/packages/schema/src/language-server/zmodel-scope.ts @@ -1,4 +1,4 @@ -import { isEnumField, isModel } from '@zenstackhq/language/ast'; +import { isEnumField, isModel, Model, DataModel } from '@zenstackhq/language/ast'; import { AstNode, AstNodeDescription, @@ -10,6 +10,8 @@ import { interruptAndCheck, LangiumDocument, LangiumServices, + Mutable, + PrecomputedScopes, ReferenceInfo, Scope, stream, @@ -51,6 +53,42 @@ export class ZModelScopeComputation extends DefaultScopeComputation { return result; } + + override computeLocalScopes( + document: LangiumDocument, + cancelToken?: CancellationToken | undefined + ): Promise { + const result = super.computeLocalScopes(document, cancelToken); + + //the $resolvedFields would be used in Linking stage for all the documents + //so we need to set it at the end of the scope computation + this.resolveBaseModels(document); + return result; + } + + private resolveBaseModels(document: LangiumDocument) { + const model = document.parseResult.value as Model; + + model.declarations.forEach((decl) => { + if (decl.$type === 'DataModel') { + const dataModel = decl as DataModel; + dataModel.$resolvedFields = [...dataModel.fields]; + dataModel.superTypes.forEach((superType) => { + const superTypeDecl = superType.ref; + if (superTypeDecl) { + superTypeDecl.fields.forEach((field) => { + const cloneField = Object.assign({}, field); + cloneField.$isInherited = true; + const mutable = cloneField as Mutable; + // update container + mutable.$container = dataModel; + dataModel.$resolvedFields.push(cloneField); + }); + } + }); + } + }); + } } export class ZModelScopeProvider extends DefaultScopeProvider { diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index a23ac63e6..6ce870302 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -1,8 +1,10 @@ import { DataModel, + DataModelAttribute, DataModelField, Expression, isArrayExpr, + isDataModel, isDataModelField, isEnumField, isInvocationExpr, @@ -13,12 +15,135 @@ import { ModelImport, ReferenceExpr, } from '@zenstackhq/language/ast'; +import { PolicyOperationKind } from '@zenstackhq/runtime'; +import { getLiteral } from '@zenstackhq/sdk'; +import { AstNode, Mutable } from 'langium'; +import { isFromStdlib } from '../language-server/utils'; import { getDocument, LangiumDocuments } from 'langium'; import { URI, Utils } from 'vscode-uri'; -import { isFromStdlib } from '../language-server/utils'; + +export function extractDataModelsWithAllowRules(model: Model): DataModel[] { + return model.declarations.filter( + (d) => isDataModel(d) && d.attributes.some((attr) => attr.decl.ref?.name === '@@allow') + ) as DataModel[]; +} + +export function analyzePolicies(dataModel: DataModel) { + const allows = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@allow'); + const denies = dataModel.attributes.filter((attr) => attr.decl.ref?.name === '@@deny'); + + const create = toStaticPolicy('create', allows, denies); + const read = toStaticPolicy('read', allows, denies); + const update = toStaticPolicy('update', allows, denies); + const del = toStaticPolicy('delete', allows, denies); + const hasFieldValidation = dataModel.$resolvedFields.some((field) => + field.attributes.some((attr) => VALIDATION_ATTRIBUTES.includes(attr.decl.$refText)) + ); + + return { + allows, + denies, + create, + read, + update, + delete: del, + allowAll: create === true && read === true && update === true && del === true, + denyAll: create === false && read === false && update === false && del === false, + hasFieldValidation, + }; +} + +export function mergeBaseModel(model: Model) { + model.declarations + .filter((x) => x.$type === 'DataModel') + .forEach((decl) => { + const dataModel = decl as DataModel; + + dataModel.superTypes.forEach((superType) => { + const superTypeDecl = superType.ref; + if (superTypeDecl) { + superTypeDecl.fields.forEach((field) => { + const cloneField = Object.assign({}, field); + const mutable = cloneField as Mutable; + // update container + mutable.$container = dataModel; + dataModel.fields.push(mutable as DataModelField); + }); + + superTypeDecl.attributes.forEach((attr) => { + const cloneAttr = Object.assign({}, attr); + const mutable = cloneAttr as Mutable; + // update container + mutable.$container = dataModel; + dataModel.attributes.push(mutable as DataModelAttribute); + }); + } + }); + }); + + // remove abstract models + model.declarations = model.declarations.filter((x) => !(x.$type == 'DataModel' && x.isAbstract)); +} + +function toStaticPolicy( + operation: PolicyOperationKind, + allows: DataModelAttribute[], + denies: DataModelAttribute[] +): boolean | undefined { + const filteredDenies = forOperation(operation, denies); + if (filteredDenies.some((rule) => getLiteral(rule.args[1].value) === true)) { + // any constant true deny rule + return false; + } + + const filteredAllows = forOperation(operation, allows); + if (filteredAllows.length === 0) { + // no allow rule + return false; + } + + if ( + filteredDenies.length === 0 && + filteredAllows.some((rule) => getLiteral(rule.args[1].value) === true) + ) { + // any constant true allow rule + return true; + } + return undefined; +} + +function forOperation(operation: PolicyOperationKind, rules: DataModelAttribute[]) { + return rules.filter((rule) => { + const ops = getLiteral(rule.args[0].value); + if (!ops) { + return false; + } + if (ops === 'all') { + return true; + } + const splitOps = ops.split(',').map((p) => p.trim()); + return splitOps.includes(operation); + }); +} + +export const VALIDATION_ATTRIBUTES = [ + '@length', + '@regex', + '@startsWith', + '@endsWith', + '@email', + '@url', + '@datetime', + '@gt', + '@gte', + '@lt', + '@lte', +]; export function getIdFields(dataModel: DataModel) { - const fieldLevelId = dataModel.fields.find((f) => f.attributes.some((attr) => attr.decl.$refText === '@id')); + const fieldLevelId = dataModel.$resolvedFields.find((f) => + f.attributes.some((attr) => attr.decl.$refText === '@id') + ); if (fieldLevelId) { return [fieldLevelId]; } else { diff --git a/packages/schema/tests/generator/prisma-generator.test.ts b/packages/schema/tests/generator/prisma-generator.test.ts index 087b8d72a..fc4bfe8de 100644 --- a/packages/schema/tests/generator/prisma-generator.test.ts +++ b/packages/schema/tests/generator/prisma-generator.test.ts @@ -236,6 +236,41 @@ describe('Prisma generator test', () => { expect(content).toContain('@@schema("base")'); expect(content).toContain('schemas = ["base","transactional"]'); }); + + it('abstract model', async () => { + const model = await loadModel(` + datasource db { + provider = 'postgresql' + url = env('URL') + } + abstract model Base { + id String @id + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + } + + model Post extends Base { + title String + published Boolean @default(false) + } + `); + const { name } = tmp.fileSync({ postfix: '.prisma' }); + await new PrismaSchemaGenerator().generate(model, { + provider: '@core/prisma', + schemaPath: 'schema.zmodel', + output: name, + generateClient: false, + }); + + const content = fs.readFileSync(name, 'utf-8'); + const dmmf = await getDMMF({ datamodel: content }); + + expect(dmmf.datamodel.models.length).toBe(1); + const post = dmmf.datamodel.models[0]; + expect(post.name).toBe('Post'); + expect(post.fields.length).toBe(6); + }); + it('custom aux field names', async () => { const model = await loadModel(` datasource db { @@ -267,7 +302,7 @@ describe('Prisma generator test', () => { expect(content).toContain('@map("myTransactionField")'); }); - it('multi files', async () => { + it('abstract multi files', async () => { const model = await loadDocument(path.join(__dirname, './zmodel/schema.zmodel')); const { name } = tmp.fileSync({ postfix: '.prisma' }); @@ -275,12 +310,22 @@ describe('Prisma generator test', () => { provider: '@core/prisma', schemaPath: 'schema.zmodel', output: name, + generateClient: false, }); const content = fs.readFileSync(name, 'utf-8'); const dmmf = await getDMMF({ datamodel: content }); - expect(dmmf.datamodel.models.length).toBe(2); + expect(dmmf.datamodel.models.length).toBe(3); expect(dmmf.datamodel.enums[0].name).toBe('UserRole'); + + const post = dmmf.datamodel.models.find((m) => m.name === 'Post'); + + expect(post?.documentation?.replace(/\s/g, '')).toBe( + `@@allow('delete', ownerId == auth()) @@allow('read', owner == auth())`.replace(/\s/g, '') + ); + + const todo = dmmf.datamodel.models.find((m) => m.name === 'Todo'); + expect(todo?.documentation?.replace(/\s/g, '')).toBe(`@@allow('read', owner == auth())`.replace(/\s/g, '')); }); }); diff --git a/packages/schema/tests/generator/zmodel/schema.zmodel b/packages/schema/tests/generator/zmodel/schema.zmodel index be912e3fb..9e2c6a803 100644 --- a/packages/schema/tests/generator/zmodel/schema.zmodel +++ b/packages/schema/tests/generator/zmodel/schema.zmodel @@ -5,10 +5,14 @@ datasource db { url = env('URL') } -model Post { - id Int @id() @default(autoincrement()) - author User? @relation(fields: [authorId], references: [id]) - authorId Int? - // author has full access - @@allow('all', auth() == author) +model Post extends Basic { + title String + content String? + + @@allow('delete', ownerId == auth()) +} + +model Todo extends Basic { + title String + isCompleted Boolean } \ No newline at end of file diff --git a/packages/schema/tests/generator/zmodel/user/user.zmodel b/packages/schema/tests/generator/zmodel/user/user.zmodel index c83eb6fbe..eb351a627 100644 --- a/packages/schema/tests/generator/zmodel/user/user.zmodel +++ b/packages/schema/tests/generator/zmodel/user/user.zmodel @@ -1,9 +1,10 @@ import "../schema" model User { - id Int @id() @default(autoincrement()) + id String @id() @default(uuid()) email String @unique() name String? posts Post[] + todos Todo[] role UserRole // make user profile public @@ -14,4 +15,14 @@ model User { enum UserRole { USER ADMIN +} + +abstract model Basic { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + owner User @relation(fields: [ownerId], references: [id], onDelete: Cascade) + ownerId String + + @@allow('read', owner == auth()) } \ No newline at end of file diff --git a/packages/schema/tests/schema/abstract.test.ts b/packages/schema/tests/schema/abstract.test.ts new file mode 100644 index 000000000..a3364bc2e --- /dev/null +++ b/packages/schema/tests/schema/abstract.test.ts @@ -0,0 +1,12 @@ +import * as fs from 'fs'; +import path from 'path'; +import { loadModel } from '../utils'; + +describe('Abstract Schema Tests', () => { + it('model loading', async () => { + const content = fs.readFileSync(path.join(__dirname, './abstract.zmodel'), { + encoding: 'utf-8', + }); + await loadModel(content); + }); +}); diff --git a/packages/schema/tests/schema/abstract.zmodel b/packages/schema/tests/schema/abstract.zmodel new file mode 100644 index 000000000..d49a95640 --- /dev/null +++ b/packages/schema/tests/schema/abstract.zmodel @@ -0,0 +1,33 @@ +datasource db { + provider = 'postgresql' + url = env('DATABASE_URL') +} + +generator js { + provider = 'prisma-client-js' +} + +abstract model Base { + id Int @id @default(autoincrement()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + user User @relation(fields: [userId], references: [id]) + userId String +} + + +model Post extends Base { + title String + published Boolean @default(false) +} + +model Todo extends Base { + description String + isDone Boolean @default(false) +} + +model User { + id String @id + todos Todo[] + posts Post[] +} \ No newline at end of file diff --git a/packages/schema/tests/schema/validation/datamodel-validation.test.ts b/packages/schema/tests/schema/validation/datamodel-validation.test.ts index f9a3478fc..daa7c381e 100644 --- a/packages/schema/tests/schema/validation/datamodel-validation.test.ts +++ b/packages/schema/tests/schema/validation/datamodel-validation.test.ts @@ -590,4 +590,25 @@ describe('Data Model Validation Tests', () => { } `); }); + + it('abstract base type', async () => { + const errors = await loadModelWithError(` + ${prelude} + + abstract model Base { + id String @id + } + + model A { + a String + } + + model B extends Base,A { + b String + } + `); + expect(errors.length).toBe(1); + + expect(errors[0]).toEqual(`Model A cannot be extended because it's not abstract`); + }); }); diff --git a/packages/schema/tests/utils.ts b/packages/schema/tests/utils.ts index 91eb81cde..f362b4019 100644 --- a/packages/schema/tests/utils.ts +++ b/packages/schema/tests/utils.ts @@ -5,6 +5,7 @@ import * as path from 'path'; import * as tmp from 'tmp'; import { URI } from 'vscode-uri'; import { createZModelServices } from '../src/language-server/zmodel-module'; +import { mergeBaseModel } from '../src/utils/ast-utils'; export class SchemaLoadingError extends Error { constructor(public readonly errors: string[]) { @@ -49,6 +50,9 @@ export async function loadModel(content: string, validate = true, verbose = true } const model = (await doc.parseResult.value) as Model; + + mergeBaseModel(model); + return model; }