diff --git a/packages/language/src/ast.ts b/packages/language/src/ast.ts index c8637115a..86dd55bed 100644 --- a/packages/language/src/ast.ts +++ b/packages/language/src/ast.ts @@ -1,7 +1,8 @@ -import { AbstractDeclaration, ExpressionType, BinaryExpr } from './generated/ast'; +import { AstNode } from 'langium'; +import { AbstractDeclaration, BinaryExpr, DataModel, ExpressionType } from './generated/ast'; -export * from './generated/ast'; export { AstNode, Reference } from 'langium'; +export * from './generated/ast'; /** * Shape of type resolution result: an expression type or reference to a declaration @@ -44,18 +45,19 @@ declare module './generated/ast' { $resolvedParam?: AttributeParam; } - interface DataModel { - /** - * Resolved fields, include inherited fields - */ - $resolvedFields: Array; + interface DataModelField { + $inheritedFrom?: DataModel; } - interface DataModelField { - $isInherited?: boolean; + interface DataModelAttribute { + $inheritedFrom?: DataModel; } } +export interface InheritableNode extends AstNode { + $inheritedFrom?: DataModel; +} + declare module 'langium' { export interface AstNode { /** diff --git a/packages/runtime/src/enhancements/create-enhancement.ts b/packages/runtime/src/enhancements/create-enhancement.ts index e3204cd52..b137e03f9 100644 --- a/packages/runtime/src/enhancements/create-enhancement.ts +++ b/packages/runtime/src/enhancements/create-enhancement.ts @@ -2,12 +2,12 @@ import semver from 'semver'; import { PRISMA_MINIMUM_VERSION } from '../constants'; import { ModelMeta } from '../cross'; import type { AuthUser } from '../types'; +import { withDefaultAuth } from './default-auth'; import { withOmit } from './omit'; import { withPassword } from './password'; import { withPolicy } from './policy'; import type { ErrorTransformer } from './proxy'; import type { PolicyDef, ZodSchemas } from './types'; -import { withDefaultAuth } from './default-auth'; /** * Kinds of enhancements to `PrismaClient` diff --git a/packages/runtime/src/enhancements/utils.ts b/packages/runtime/src/enhancements/utils.ts index 2879a3119..ba2f9a2d8 100644 --- a/packages/runtime/src/enhancements/utils.ts +++ b/packages/runtime/src/enhancements/utils.ts @@ -22,18 +22,3 @@ export function prismaClientKnownRequestError(prisma: DbClientContract, prismaMo export function prismaClientUnknownRequestError(prismaModule: any, ...args: unknown[]): Error { throw new prismaModule.PrismaClientUnknownRequestError(...args); } - -export function deepGet(object: object, path: string | string[] | undefined, defaultValue: unknown): unknown { - if (path === undefined || path === '') { - return defaultValue; - } - const keys = Array.isArray(path) ? path : path.split('.'); - for (const key of keys) { - if (object && typeof object === 'object' && key in object) { - object = object[key as keyof typeof object]; - } else { - return defaultValue; - } - } - return object !== undefined ? object : defaultValue; -} diff --git a/packages/schema/src/cli/cli-util.ts b/packages/schema/src/cli/cli-util.ts index 000e92ca7..2cfa18fcb 100644 --- a/packages/schema/src/cli/cli-util.ts +++ b/packages/schema/src/cli/cli-util.ts @@ -89,7 +89,7 @@ export async function loadDocument(fileName: string): Promise { validationAfterMerge(model); - mergeBaseModel(model); + mergeBaseModel(model, services.references.Linker); return model; } diff --git a/packages/schema/src/language-server/validator/datamodel-validator.ts b/packages/schema/src/language-server/validator/datamodel-validator.ts index ce1886f5e..33ec0ff37 100644 --- a/packages/schema/src/language-server/validator/datamodel-validator.ts +++ b/packages/schema/src/language-server/validator/datamodel-validator.ts @@ -6,7 +6,13 @@ import { isStringLiteral, ReferenceExpr, } from '@zenstackhq/language/ast'; -import { analyzePolicies, getLiteral, getModelIdFields, getModelUniqueFields } from '@zenstackhq/sdk'; +import { + analyzePolicies, + getLiteral, + getModelFieldsWithBases, + getModelIdFields, + getModelUniqueFields, +} from '@zenstackhq/sdk'; import { AstNode, DiagnosticInfo, getDocument, ValidationAcceptor } from 'langium'; import { IssueCodes, SCALAR_TYPES } from '../constants'; import { AstValidator } from '../types'; @@ -20,16 +26,15 @@ import { validateDuplicatedDeclarations } from './utils'; export default class DataModelValidator implements AstValidator { validate(dm: DataModel, accept: ValidationAcceptor): void { this.validateBaseAbstractModel(dm, accept); - validateDuplicatedDeclarations(dm.$resolvedFields, accept); + validateDuplicatedDeclarations(getModelFieldsWithBases(dm), accept); this.validateAttributes(dm, accept); this.validateFields(dm, accept); } private validateFields(dm: DataModel, accept: ValidationAcceptor) { - const idFields = dm.$resolvedFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id')); - const uniqueFields = dm.$resolvedFields.filter((f) => - f.attributes.find((attr) => attr.decl.ref?.name === '@unique') - ); + const allFields = getModelFieldsWithBases(dm); + const idFields = allFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id')); + const uniqueFields = allFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@unique')); const modelLevelIds = getModelIdFields(dm); const modelUniqueFields = getModelUniqueFields(dm); @@ -42,7 +47,7 @@ export default class DataModelValidator implements AstValidator { const { allows, denies, hasFieldValidation } = analyzePolicies(dm); if (allows.length > 0 || denies.length > 0 || hasFieldValidation) { // TODO: relax this requirement to require only @unique fields - // when access policies or field valdaition is used, require an @id field + // when access policies or field validation is used, require an @id field accept( 'error', 'Model must include a field with @id or @unique attribute, or a model-level @@id or @@unique attribute to use access policies', @@ -74,10 +79,10 @@ export default class DataModelValidator implements AstValidator { dm.fields.forEach((field) => this.validateField(field, accept)); if (!dm.isAbstract) { - dm.$resolvedFields + allFields .filter((x) => isDataModel(x.type.reference?.ref)) .forEach((y) => { - this.validateRelationField(y, accept); + this.validateRelationField(dm, y, accept); }); } } @@ -194,7 +199,7 @@ export default class DataModelValidator implements AstValidator { // points back const oppositeModel = field.type.reference?.ref as DataModel; if (oppositeModel) { - const oppositeModelFields = oppositeModel.$resolvedFields as DataModelField[]; + const oppositeModelFields = getModelFieldsWithBases(oppositeModel); for (const oppositeField of oppositeModelFields) { // find the opposite relation with the matching name const relAttr = oppositeField.attributes.find((a) => a.decl.ref?.name === '@relation'); @@ -213,7 +218,7 @@ export default class DataModelValidator implements AstValidator { return false; } - private validateRelationField(field: DataModelField, accept: ValidationAcceptor) { + private validateRelationField(contextModel: DataModel, field: DataModelField, accept: ValidationAcceptor) { const thisRelation = this.parseRelation(field, accept); if (!thisRelation.valid) { return; @@ -223,8 +228,8 @@ export default class DataModelValidator implements AstValidator { const oppositeModel = field.type.reference!.ref! as DataModel; // Use name because the current document might be updated - let oppositeFields = oppositeModel.$resolvedFields.filter( - (f) => f.type.reference?.ref?.name === field.$container.name + let oppositeFields = getModelFieldsWithBases(oppositeModel).filter( + (f) => f.type.reference?.ref?.name === contextModel.name ); oppositeFields = oppositeFields.filter((f) => { const fieldRel = this.parseRelation(f); @@ -232,13 +237,13 @@ export default class DataModelValidator implements AstValidator { }); if (oppositeFields.length === 0) { - const node = field.$isInherited ? field.$container : field; - const info: DiagnosticInfo = { node, code: IssueCodes.MissingOppositeRelation }; + const info: DiagnosticInfo = { + node: field, + code: IssueCodes.MissingOppositeRelation, + }; info.property = 'name'; - // use cstNode because the field might be inherited from parent model - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - const container = field.$cstNode!.element.$container as DataModel; + const container = field.$container; const relationFieldDocUri = getDocument(container).textDocument.uri; const relationDataModelName = container.name; @@ -247,20 +252,20 @@ export default class DataModelValidator implements AstValidator { relationFieldName: field.name, relationDataModelName, relationFieldDocUri, - dataModelName: field.$container.name, + dataModelName: contextModel.name, }; info.data = data; accept( 'error', - `The relation field "${field.name}" on model "${field.$container.name}" is missing an opposite relation field on model "${oppositeModel.name}"`, + `The relation field "${field.name}" on model "${contextModel.name}" is missing an opposite relation field on model "${oppositeModel.name}"`, info ); return; } else if (oppositeFields.length > 1) { oppositeFields - .filter((x) => !x.$isInherited) + .filter((x) => !x.$inheritedFrom) .forEach((f) => { if (this.isSelfRelation(f)) { // self relations are partial diff --git a/packages/schema/src/language-server/validator/utils.ts b/packages/schema/src/language-server/validator/utils.ts index 50e2263d7..340f471b8 100644 --- a/packages/schema/src/language-server/validator/utils.ts +++ b/packages/schema/src/language-server/validator/utils.ts @@ -33,8 +33,8 @@ export function validateDuplicatedDeclarations( for (const [name, decls] of Object.entries(groupByName)) { if (decls.length > 1) { let errorField = decls[1]; - if (decls[0].$type === 'DataModelField') { - const nonInheritedFields = decls.filter((x) => !(x as DataModelField).$isInherited); + if (isDataModelField(decls[0])) { + const nonInheritedFields = decls.filter((x) => !(x as DataModelField).$inheritedFrom); if (nonInheritedFields.length > 0) { errorField = nonInheritedFields.slice(-1)[0]; } diff --git a/packages/schema/src/language-server/zmodel-code-action.ts b/packages/schema/src/language-server/zmodel-code-action.ts index aace4d0fe..8f60cbe69 100644 --- a/packages/schema/src/language-server/zmodel-code-action.ts +++ b/packages/schema/src/language-server/zmodel-code-action.ts @@ -2,18 +2,19 @@ import { DataModel, DataModelField, Model, isDataModel } from '@zenstackhq/langu import { AstReflection, CodeActionProvider, - getDocument, IndexManager, LangiumDocument, LangiumDocuments, LangiumServices, MaybePromise, + getDocument, } from 'langium'; +import { getModelFieldsWithBases } from '@zenstackhq/sdk'; import { CodeAction, CodeActionKind, CodeActionParams, Command, Diagnostic } from 'vscode-languageserver'; import { IssueCodes } from './constants'; -import { ZModelFormatter } from './zmodel-formatter'; import { MissingOppositeRelationData } from './validator/datamodel-validator'; +import { ZModelFormatter } from './zmodel-formatter'; export class ZModelCodeActionProvider implements CodeActionProvider { protected readonly reflection: AstReflection; @@ -92,8 +93,8 @@ export class ZModelCodeActionProvider implements CodeActionProvider { let newText = ''; if (fieldAstNode.type.array) { - //post Post[] - const idField = container.$resolvedFields.find((f) => + // post Post[] + const idField = getModelFieldsWithBases(container).find((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id') ) as DataModelField; @@ -111,7 +112,7 @@ export class ZModelCodeActionProvider implements CodeActionProvider { const idFieldName = idField.name; const referenceIdFieldName = fieldName + this.upperCaseFirstLetter(idFieldName); - if (!oppositeModel.$resolvedFields.find((f) => f.name === referenceIdFieldName)) { + if (!getModelFieldsWithBases(oppositeModel).find((f) => f.name === referenceIdFieldName)) { referenceField = '\n' + indent + `${referenceIdFieldName} ${idField.type.type}`; } diff --git a/packages/schema/src/language-server/zmodel-linker.ts b/packages/schema/src/language-server/zmodel-linker.ts index 8c8fb2c98..30929791f 100644 --- a/packages/schema/src/language-server/zmodel-linker.ts +++ b/packages/schema/src/language-server/zmodel-linker.ts @@ -35,7 +35,13 @@ import { isReferenceExpr, isStringLiteral, } from '@zenstackhq/language/ast'; -import { getContainingModel, hasAttribute, isAuthInvocation, isFutureExpr } from '@zenstackhq/sdk'; +import { + getContainingModel, + getModelFieldsWithBases, + hasAttribute, + isAuthInvocation, + isFutureExpr, +} from '@zenstackhq/sdk'; import { AstNode, AstNodeDescription, @@ -52,7 +58,7 @@ import { } from 'langium'; import { match } from 'ts-pattern'; import { CancellationToken } from 'vscode-jsonrpc'; -import { getAllDeclarationsFromImports, getContainingDataModel, isCollectionPredicate } from '../utils/ast-utils'; +import { getAllDeclarationsFromImports, getContainingDataModel } from '../utils/ast-utils'; import { mapBuiltinTypeToExpressionType } from './validator/utils'; interface DefaultReference extends Reference { @@ -256,26 +262,9 @@ export class ZModelLinker extends DefaultLinker { } private resolveReference(node: ReferenceExpr, document: LangiumDocument, extraScopes: ScopeProvider[]) { - this.linkReference(node, 'target', document, extraScopes); - node.args.forEach((arg) => this.resolve(arg, document, extraScopes)); + this.resolveDefault(node, document, extraScopes); if (node.target.ref) { - // if the reference is inside the RHS of a collection predicate, it cannot be resolve to a field - // not belonging to the collection's model type - - const collectionPredicateContext = this.getCollectionPredicateContextDataModel(node); - if ( - // inside a collection predicate RHS - collectionPredicateContext && - // current ref expr is resolved to a field - isDataModelField(node.target.ref) && - // the resolved field doesn't belong to the collection predicate's operand's type - node.target.ref.$container !== collectionPredicateContext - ) { - this.unresolvableRefExpr(node); - return; - } - // resolve type if (node.target.ref.$type === EnumField) { this.resolveToBuiltinTypeOrDecl(node, node.target.ref.$container); @@ -285,26 +274,6 @@ export class ZModelLinker extends DefaultLinker { } } - private getCollectionPredicateContextDataModel(node: ReferenceExpr) { - let curr: AstNode | undefined = node; - while (curr) { - if ( - curr.$container && - // parent is a collection predicate - isCollectionPredicate(curr.$container) && - // the collection predicate's LHS is resolved to a DataModel - isDataModel(curr.$container.left.$resolvedType?.decl) && - // current node is the RHS - curr.$containerProperty === 'right' - ) { - // return the resolved type of LHS - return curr.$container.left.$resolvedType?.decl; - } - curr = curr.$container; - } - return undefined; - } - private resolveArray(node: ArrayExpr, document: LangiumDocument, extraScopes: ScopeProvider[]) { node.items.forEach((item) => this.resolve(item, document, extraScopes)); @@ -367,14 +336,11 @@ export class ZModelLinker extends DefaultLinker { document: LangiumDocument, extraScopes: ScopeProvider[] ) { - this.resolve(node.operand, document, extraScopes); + this.resolveDefault(node, document, extraScopes); const operandResolved = node.operand.$resolvedType; if (operandResolved && !operandResolved.array && isDataModel(operandResolved.decl)) { - const modelDecl = operandResolved.decl as DataModel; - const provider = (name: string) => modelDecl.$resolvedFields.find((f) => f.name === name); // member access is resolved only in the context of the operand type - this.linkReference(node, 'member', document, [provider], true); if (node.member.ref) { this.resolveToDeclaredType(node, node.member.ref.type); @@ -388,20 +354,10 @@ export class ZModelLinker extends DefaultLinker { } private resolveCollectionPredicate(node: BinaryExpr, document: LangiumDocument, extraScopes: ScopeProvider[]) { - this.resolve(node.left, document, extraScopes); + this.resolveDefault(node, document, extraScopes); const resolvedType = node.left.$resolvedType; if (resolvedType && isDataModel(resolvedType.decl) && resolvedType.array) { - const dataModelDecl = resolvedType.decl; - const provider = (name: string) => { - if (name === 'this') { - return dataModelDecl; - } else { - return dataModelDecl.$resolvedFields.find((f) => f.name === name); - } - }; - extraScopes = [provider, ...extraScopes]; - this.resolve(node.right, document, extraScopes); this.resolveToBuiltinTypeOrDecl(node, 'Boolean'); } else { // error is reported in validation pass @@ -455,10 +411,11 @@ export class ZModelLinker extends DefaultLinker { // // In model B, the attribute argument "myId" is resolved to the field "myId" in model A - const transtiveDataModel = attrAppliedOn.type.reference?.ref as DataModel; - if (transtiveDataModel) { + const transitiveDataModel = attrAppliedOn.type.reference?.ref as DataModel; + if (transitiveDataModel) { // resolve references in the context of the transitive data model - const scopeProvider = (name: string) => transtiveDataModel.$resolvedFields.find((f) => f.name === name); + const scopeProvider = (name: string) => + getModelFieldsWithBases(transitiveDataModel).find((f) => f.name === name); if (isArrayExpr(node.value)) { node.value.items.forEach((item) => { if (isReferenceExpr(item)) { diff --git a/packages/schema/src/language-server/zmodel-scope.ts b/packages/schema/src/language-server/zmodel-scope.ts index 8eda869e8..21304fa4a 100644 --- a/packages/schema/src/language-server/zmodel-scope.ts +++ b/packages/schema/src/language-server/zmodel-scope.ts @@ -1,7 +1,6 @@ import { - DataModel, + BinaryExpr, MemberAccessExpr, - Model, isDataModel, isDataModelField, isEnumField, @@ -9,8 +8,16 @@ import { isMemberAccessExpr, isModel, isReferenceExpr, + isThisExpr, } from '@zenstackhq/language/ast'; -import { getAuthModel, getDataModels } from '@zenstackhq/sdk'; +import { + getAuthModel, + getDataModels, + getModelFieldsWithBases, + getRecursiveBases, + isAuthInvocation, + isFutureExpr, +} from '@zenstackhq/sdk'; import { AstNode, AstNodeDescription, @@ -19,7 +26,6 @@ import { EMPTY_SCOPE, LangiumDocument, LangiumServices, - Mutable, PrecomputedScopes, ReferenceInfo, Scope, @@ -30,8 +36,9 @@ import { stream, streamAllContents, } from 'langium'; +import { match } from 'ts-pattern'; import { CancellationToken } from 'vscode-jsonrpc'; -import { resolveImportUri } from '../utils/ast-utils'; +import { isCollectionPredicate, resolveImportUri } from '../utils/ast-utils'; import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from './constants'; /** @@ -66,49 +73,18 @@ export class ZModelScopeComputation extends DefaultScopeComputation { return result; } - override computeLocalScopes( - document: LangiumDocument, - cancelToken?: CancellationToken | undefined - ): Promise { - const result = super.computeLocalScopes(document, cancelToken); - - //the $resolvedFields would be used in Linking stage for all the documents - //so we need to set it at the end of the scope computation - this.resolveBaseModels(document); - return result; - } - - private resolveBaseModels(document: LangiumDocument) { - const model = document.parseResult.value as Model; - - model.declarations.forEach((decl) => { - if (decl.$type === 'DataModel') { - const dataModel = decl as DataModel; - dataModel.$resolvedFields = [...dataModel.fields]; - this.getRecursiveSuperTypes(dataModel).forEach((superType) => { - superType.fields.forEach((field) => { - const cloneField = Object.assign({}, field); - cloneField.$isInherited = true; - const mutable = cloneField as Mutable; - // update container - mutable.$container = dataModel; - dataModel.$resolvedFields.push(cloneField); - }); - }); - } - }); - } + override processNode(node: AstNode, document: LangiumDocument, scopes: PrecomputedScopes) { + super.processNode(node, document, scopes); - private getRecursiveSuperTypes(dataModel: DataModel): DataModel[] { - const result: DataModel[] = []; - dataModel.superTypes.forEach((superType) => { - const superTypeDecl = superType.ref; - if (superTypeDecl) { - result.push(superTypeDecl); - result.push(...this.getRecursiveSuperTypes(superTypeDecl)); + if (isDataModel(node)) { + // add base fields to the scope recursively + const bases = getRecursiveBases(node); + for (const base of bases) { + for (const field of base.fields) { + scopes.add(node, this.descriptions.createDescription(field, this.nameProvider.getName(field))); + } } - }); - return result; + } } } @@ -140,50 +116,129 @@ export class ZModelScopeProvider extends DefaultScopeProvider { override getScope(context: ReferenceInfo): Scope { if (isMemberAccessExpr(context.container) && context.container.operand && context.property === 'member') { - return this.getMemberAccessScope(context.container); + return this.getMemberAccessScope(context); + } + + if (isReferenceExpr(context.container) && context.property === 'target') { + // when reference expression is resolved inside a collection predicate, the scope is the collection + const containerCollectionPredicate = getCollectionPredicateContext(context.container); + if (containerCollectionPredicate) { + return this.getCollectionPredicateScope(context, containerCollectionPredicate); + } } + return super.getScope(context); } - private getMemberAccessScope(node: MemberAccessExpr) { - if (isReferenceExpr(node.operand)) { - // scope to target model's fields - const ref = node.operand.target.ref; - if (isDataModelField(ref)) { - const targetModel = ref.type.reference?.ref; - if (isDataModel(targetModel)) { - return this.createScopeForNodes(targetModel.fields); + private getMemberAccessScope(context: ReferenceInfo) { + const referenceType = this.reflection.getReferenceType(context); + const globalScope = this.getGlobalScope(referenceType, context); + const node = context.container as MemberAccessExpr; + + return match(node.operand) + .when(isReferenceExpr, (operand) => { + // operand is a reference, it can only be a model field + const ref = operand.target.ref; + if (isDataModelField(ref)) { + const targetModel = ref.type.reference?.ref; + return this.createScopeForModel(targetModel, globalScope); } - } - } else if (isMemberAccessExpr(node.operand)) { - // scope to target model's fields - const ref = node.operand.member.ref; - if (isDataModelField(ref)) { - const targetModel = ref.type.reference?.ref; - if (isDataModel(targetModel)) { - return this.createScopeForNodes(targetModel.fields); + return EMPTY_SCOPE; + }) + .when(isMemberAccessExpr, (operand) => { + // operand is a member access, it must be resolved to a + const ref = operand.member.ref; + if (isDataModelField(ref)) { + const targetModel = ref.type.reference?.ref; + return this.createScopeForModel(targetModel, globalScope); } - } - } else if (isInvocationExpr(node.operand)) { - // deal with member access from `auth()` and `future() - const funcName = node.operand.function.$refText; - if (funcName === 'auth') { - // resolve to `User` or `@@auth` model - const model = getContainerOfType(node, isModel); - if (model) { - const authModel = getAuthModel(getDataModels(model)); - if (authModel) { - return this.createScopeForNodes(authModel.fields); - } + return EMPTY_SCOPE; + }) + .when(isThisExpr, () => { + // operand is `this`, resolve to the containing model + return this.createScopeForContainingModel(node, globalScope); + }) + .when(isInvocationExpr, (operand) => { + // deal with member access from `auth()` and `future() + if (isAuthInvocation(operand)) { + // resolve to `User` or `@@auth` model + return this.createScopeForAuthModel(node, globalScope); } - } - if (funcName === 'future') { - const thisModel = getContainerOfType(node, isDataModel); - if (thisModel) { - return this.createScopeForNodes(thisModel.fields); + if (isFutureExpr(operand)) { + // resolve `future()` to the containing model + return this.createScopeForContainingModel(node, globalScope); } + return EMPTY_SCOPE; + }) + .otherwise(() => EMPTY_SCOPE); + } + + private getCollectionPredicateScope(context: ReferenceInfo, collectionPredicate: BinaryExpr) { + const referenceType = this.reflection.getReferenceType(context); + const globalScope = this.getGlobalScope(referenceType, context); + const collection = collectionPredicate.left; + + return match(collection) + .when(isReferenceExpr, (expr) => { + // collection is a reference, it can only be a model field + const ref = expr.target.ref; + if (isDataModelField(ref)) { + const targetModel = ref.type.reference?.ref; + return this.createScopeForModel(targetModel, globalScope); + } + return EMPTY_SCOPE; + }) + .when(isMemberAccessExpr, (expr) => { + // collection is a member access, it can only be resolved to a model field + const ref = expr.member.ref; + if (isDataModelField(ref)) { + const targetModel = ref.type.reference?.ref; + return this.createScopeForModel(targetModel, globalScope); + } + return EMPTY_SCOPE; + }) + .when(isAuthInvocation, (expr) => { + return this.createScopeForAuthModel(expr, globalScope); + }) + .otherwise(() => EMPTY_SCOPE); + } + + private createScopeForContainingModel(node: AstNode, globalScope: Scope) { + const model = getContainerOfType(node, isDataModel); + if (model) { + return this.createScopeForNodes(model.fields, globalScope); + } else { + return EMPTY_SCOPE; + } + } + + private createScopeForModel(node: AstNode | undefined, globalScope: Scope) { + if (isDataModel(node)) { + return this.createScopeForNodes(getModelFieldsWithBases(node), globalScope); + } else { + return EMPTY_SCOPE; + } + } + + private createScopeForAuthModel(node: AstNode, globalScope: Scope) { + const model = getContainerOfType(node, isModel); + if (model) { + const authModel = getAuthModel(getDataModels(model, true)); + if (authModel) { + return this.createScopeForNodes(authModel.fields, globalScope); } } return EMPTY_SCOPE; } } + +function getCollectionPredicateContext(node: AstNode) { + let curr: AstNode | undefined = node; + while (curr) { + if (curr.$container && isCollectionPredicate(curr.$container) && curr.$containerProperty === 'right') { + return curr.$container; + } + curr = curr.$container; + } + return undefined; +} diff --git a/packages/schema/src/plugins/prisma/prisma-builder.ts b/packages/schema/src/plugins/prisma/prisma-builder.ts index 68336baeb..64777b62e 100644 --- a/packages/schema/src/plugins/prisma/prisma-builder.ts +++ b/packages/schema/src/plugins/prisma/prisma-builder.ts @@ -310,6 +310,7 @@ export class FunctionCallArg { return this.name ? `${this.name}: ${this.value}` : this.value; } } + export class Enum extends ContainerDeclaration { public fields: EnumField[] = []; diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 0f25ab1b8..d7a32ad30 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -30,10 +30,11 @@ import { match } from 'ts-pattern'; import { PRISMA_MINIMUM_VERSION } from '@zenstackhq/runtime'; import { + getAttribute, getDMMF, getLiteral, getPrismaVersion, - isDefaultAuthField, + isAuthInvocation, PluginError, PluginOptions, resolved, @@ -42,6 +43,7 @@ import { } from '@zenstackhq/sdk'; import fs from 'fs'; import { writeFile } from 'fs/promises'; +import { streamAst } from 'langium'; import path from 'path'; import semver from 'semver'; import stripColor from 'strip-color'; @@ -325,7 +327,7 @@ export default class PrismaSchemaGenerator { } private getAttributesToGenerate(field: DataModelField) { - if (isDefaultAuthField(field)) { + if (this.hasDefaultWithAuth(field)) { return []; } return field.attributes @@ -333,6 +335,21 @@ export default class PrismaSchemaGenerator { .map((attr) => this.makeFieldAttribute(attr)); } + private hasDefaultWithAuth(field: DataModelField) { + const defaultAttr = getAttribute(field, '@default'); + if (!defaultAttr) { + return false; + } + + const expr = defaultAttr.args[0]?.value; + if (!expr) { + return false; + } + + // find `auth()` in default value expression + return streamAst(expr).some(isAuthInvocation); + } + private makeFieldAttribute(attr: DataModelFieldAttribute) { const attrName = resolved(attr.decl).name; if (attrName === FIELD_PASSTHROUGH_ATTR) { diff --git a/packages/schema/src/telemetry.ts b/packages/schema/src/telemetry.ts index 3166a5f9b..9cd8ba386 100644 --- a/packages/schema/src/telemetry.ts +++ b/packages/schema/src/telemetry.ts @@ -8,8 +8,8 @@ import sleep from 'sleep-promise'; import { CliError } from './cli/cli-error'; import { TELEMETRY_TRACKING_TOKEN } from './constants'; import isDocker from './utils/is-docker'; -import { getVersion } from './utils/version-utils'; import { getMachineId } from './utils/machine-id-utils'; +import { getVersion } from './utils/version-utils'; /** * Telemetry events diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index 80543d6a2..1e2850577 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -2,13 +2,27 @@ import { BinaryExpr, DataModel, Expression, + InheritableNode, isBinaryExpr, isDataModel, isModel, Model, ModelImport, } from '@zenstackhq/language/ast'; -import { AstNode, getDocument, LangiumDocuments, Mutable } from 'langium'; +import { + AstNode, + CstNode, + GenericAstNode, + getContainerOfType, + getDocument, + isAstNode, + isReference, + LangiumDocuments, + linkContentToContainer, + Linker, + Mutable, + Reference, +} from 'langium'; import { URI, Utils } from 'vscode-uri'; export function extractDataModelsWithAllowRules(model: Model): DataModel[] { @@ -17,7 +31,16 @@ export function extractDataModelsWithAllowRules(model: Model): DataModel[] { ) as DataModel[]; } -export function mergeBaseModel(model: Model) { +type BuildReference = ( + node: AstNode, + property: string, + refNode: CstNode | undefined, + refText: string +) => Reference; + +export function mergeBaseModel(model: Model, linker: Linker) { + const buildReference = linker.buildReference.bind(linker); + model.declarations .filter((x) => x.$type === 'DataModel') .forEach((decl) => { @@ -25,27 +48,65 @@ export function mergeBaseModel(model: Model) { dataModel.fields = dataModel.superTypes // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - .flatMap((superType) => updateContainer(superType.ref!.fields, dataModel)) + .flatMap((superType) => superType.ref!.fields) + .map((f) => cloneAst(f, dataModel, buildReference)) .concat(dataModel.fields); dataModel.attributes = dataModel.superTypes // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - .flatMap((superType) => updateContainer(superType.ref!.attributes, dataModel)) + .flatMap((superType) => superType.ref!.attributes) + .map((attr) => cloneAst(attr, dataModel, buildReference)) .concat(dataModel.attributes); }); // remove abstract models - model.declarations = model.declarations.filter((x) => !(x.$type == 'DataModel' && x.isAbstract)); + model.declarations = model.declarations.filter((x) => !(isDataModel(x) && x.isAbstract)); +} + +// deep clone an AST, relink references, and set its container +function cloneAst( + node: T, + newContainer: AstNode, + buildReference: BuildReference +): Mutable { + const clone = copyAstNode(node, buildReference) as Mutable; + clone.$container = newContainer; + clone.$containerProperty = node.$containerProperty; + clone.$containerIndex = node.$containerIndex; + clone.$inheritedFrom = getContainerOfType(node, isDataModel); + return clone; } -function updateContainer(nodes: T[], container: AstNode): Mutable[] { - return nodes.map((node) => { - const cloneField = Object.assign({}, node); - const mutable = cloneField as Mutable; - // update container - mutable.$container = container; - return mutable; - }); +// this function is copied from Langium's ast-utils, but copying $resolvedType as well +function copyAstNode(node: T, buildReference: BuildReference): T { + const copy: GenericAstNode = { $type: node.$type, $resolvedType: node.$resolvedType }; + + for (const [name, value] of Object.entries(node)) { + if (!name.startsWith('$')) { + if (isAstNode(value)) { + copy[name] = copyAstNode(value, buildReference); + } else if (isReference(value)) { + copy[name] = buildReference(copy, name, value.$refNode, value.$refText); + } else if (Array.isArray(value)) { + const copiedArray: unknown[] = []; + for (const element of value) { + if (isAstNode(element)) { + copiedArray.push(copyAstNode(element, buildReference)); + } else if (isReference(element)) { + copiedArray.push(buildReference(copy, name, element.$refNode, element.$refText)); + } else { + copiedArray.push(element); + } + } + copy[name] = copiedArray; + } else { + copy[name] = value; + } + } + } + + linkContentToContainer(copy); + return copy as unknown as T; } export function resolveImportUri(imp: ModelImport): URI | undefined { diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index cb2f788d4..dfc1d650c 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -1088,11 +1088,14 @@ describe('Attribute tests', () => { model A { id String @id x Int + b B @relation(references: [id], fields: [bId]) + bId String @unique } model B { id String @id - a A + a A? + aId String @unique @@allow('all', a?[x > 0]) } `) diff --git a/packages/schema/tests/utils.ts b/packages/schema/tests/utils.ts index f88aae6e2..4dcd45170 100644 --- a/packages/schema/tests/utils.ts +++ b/packages/schema/tests/utils.ts @@ -16,7 +16,7 @@ export class SchemaLoadingError extends Error { export async function loadModel(content: string, validate = true, verbose = true, mergeBase = true) { const { name: docPath } = tmp.fileSync({ postfix: '.zmodel' }); fs.writeFileSync(docPath, content); - const { shared } = createZModelServices(NodeFileSystem); + const { shared, ZModel } = createZModelServices(NodeFileSystem); const stdLib = shared.workspace.LangiumDocuments.getOrCreateDocument( URI.file(path.resolve(__dirname, '../../schema/src/res/stdlib.zmodel')) ); @@ -52,7 +52,7 @@ export async function loadModel(content: string, validate = true, verbose = true const model = (await doc.parseResult.value) as Model; if (mergeBase) { - mergeBaseModel(model); + mergeBaseModel(model, ZModel.references.Linker); } return model; diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index 2f046b692..ed841dbc7 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -32,7 +32,7 @@ import { } from '@zenstackhq/language/ast'; import path from 'path'; import { ExpressionContext, STD_LIB_MODULE_NAME } from './constants'; -import { PluginDeclaredOptions, PluginError, PluginOptions } from './types'; +import { PluginError, type PluginDeclaredOptions, type PluginOptions } from './types'; /** * Gets data models that are not ignored @@ -281,13 +281,6 @@ export function isForeignKeyField(field: DataModelField) { }); } -export function isDefaultAuthField(field: DataModelField) { - return ( - hasAttribute(field, '@default') && - !!field.attributes.find((attr) => attr.args?.[0]?.value.$cstNode?.text.startsWith('auth()')) - ); -} - export function resolvePath(_path: string, options: Pick) { if (path.isAbsolute(_path)) { return _path; @@ -387,7 +380,7 @@ export function getAuthModel(dataModels: DataModel[]) { } export function getIdFields(dataModel: DataModel) { - const fieldLevelId = dataModel.$resolvedFields.find((f) => + const fieldLevelId = getModelFieldsWithBases(dataModel).find((f) => f.attributes.some((attr) => attr.decl.$refText === '@id') ); if (fieldLevelId) { @@ -418,3 +411,19 @@ export function getDataModelFieldReference(expr: Expression): DataModelField | u return undefined; } } + +export function getModelFieldsWithBases(model: DataModel) { + return [...model.fields, ...getRecursiveBases(model).flatMap((base) => base.fields)]; +} + +export function getRecursiveBases(dataModel: DataModel): DataModel[] { + const result: DataModel[] = []; + dataModel.superTypes.forEach((superType) => { + const baseDecl = superType.ref; + if (baseDecl) { + result.push(baseDecl); + result.push(...getRecursiveBases(baseDecl)); + } + }); + return result; +} diff --git a/packages/testtools/src/model.ts b/packages/testtools/src/model.ts index 4be8a1613..29b15467d 100644 --- a/packages/testtools/src/model.ts +++ b/packages/testtools/src/model.ts @@ -16,7 +16,7 @@ export class SchemaLoadingError extends Error { export async function loadModel(content: string, validate = true, verbose = true) { const { name: docPath } = tmp.fileSync({ postfix: '.zmodel' }); fs.writeFileSync(docPath, content); - const { shared } = createZModelServices(NodeFileSystem); + const { shared, ZModel } = createZModelServices(NodeFileSystem); const stdLib = shared.workspace.LangiumDocuments.getOrCreateDocument( URI.file(path.resolve(__dirname, '../../schema/src/res/stdlib.zmodel')) ); @@ -51,7 +51,7 @@ export async function loadModel(content: string, validate = true, verbose = true const model = (await doc.parseResult.value) as Model; - mergeBaseModel(model); + mergeBaseModel(model, ZModel.references.Linker); return model; } diff --git a/tests/integration/tests/regression/issue-925.test.ts b/tests/integration/tests/regression/issue-925.test.ts index 34b1ac434..b19d9d615 100644 --- a/tests/integration/tests/regression/issue-925.test.ts +++ b/tests/integration/tests/regression/issue-925.test.ts @@ -1,7 +1,7 @@ -import { loadModelWithError } from '@zenstackhq/testtools'; +import { loadModel, loadModelWithError } from '@zenstackhq/testtools'; describe('Regression: issue 925', () => { - it('member reference from this', async () => { + it('member reference without using this', async () => { await expect( loadModelWithError( ` @@ -10,7 +10,7 @@ describe('Regression: issue 925', () => { company Company[] test Int - @@allow('read', auth().company?[staff?[companyId == this.test]]) + @@allow('read', auth().company?[staff?[companyId == test]]) } model Company { @@ -32,19 +32,18 @@ describe('Regression: issue 925', () => { } ` ) - ).resolves.toContain("Could not resolve reference to DataModelField named 'test'."); + ).resolves.toContain("Could not resolve reference to ReferenceTarget named 'test'."); }); - it('simple reference', async () => { - await expect( - loadModelWithError( - ` + it('reference with this', async () => { + await loadModel( + ` model User { id Int @id @default(autoincrement()) company Company[] test Int - @@allow('read', auth().company?[staff?[companyId == test]]) + @@allow('read', auth().company?[staff?[companyId == this.test]]) } model Company { @@ -65,7 +64,6 @@ describe('Regression: issue 925', () => { @@allow('read', true) } ` - ) - ).resolves.toContain("Could not resolve reference to ReferenceTarget named 'test'."); + ); }); }); diff --git a/tests/integration/tests/regression/issues.test.ts b/tests/integration/tests/regression/issues.test.ts index 4ade85c8c..7c2ca94cd 100644 --- a/tests/integration/tests/regression/issues.test.ts +++ b/tests/integration/tests/regression/issues.test.ts @@ -327,9 +327,9 @@ model User { // can be created by anyone, even not logged in @@allow('create', true) // can be read by users in the same organization - @@allow('read', orgs?[members?[auth() == this]]) + @@allow('read', orgs?[members?[auth().id == id]]) // full access by oneself - @@allow('all', auth() == this) + @@allow('all', auth().id == id) } model Organization { @@ -343,7 +343,7 @@ model Organization { // everyone can create a organization @@allow('create', true) // any user in the organization can read the organization - @@allow('read', members?[auth() == this]) + @@allow('read', members?[auth().id == id]) } abstract model organizationBaseEntity { @@ -359,15 +359,15 @@ abstract model organizationBaseEntity { groups Group[] // when create, owner must be set to current user, and user must be in the organization - @@allow('create', owner == auth() && org.members?[this == auth()]) + @@allow('create', owner == auth() && org.members?[id == auth().id]) // only the owner can update it and is not allowed to change the owner - @@allow('update', owner == auth() && org.members?[this == auth()] && future().owner == owner) + @@allow('update', owner == auth() && org.members?[id == auth().id] && future().owner == owner) // allow owner to read @@allow('read', owner == auth()) // allow shared group members to read it - @@allow('read', groups?[users?[this == auth()]]) + @@allow('read', groups?[users?[id == auth().id]]) // allow organization to access if public - @@allow('read', isPublic && org.members?[this == auth()]) + @@allow('read', isPublic && org.members?[id == auth().id]) // can not be read if deleted @@deny('all', isDeleted == true) } @@ -394,7 +394,7 @@ model Group { orgId String // group is shared by organization - @@allow('all', org.members?[auth() == this]) + @@allow('all', org.members?[auth().id == id]) } ` ); @@ -616,7 +616,7 @@ model Organization { // everyone can create a organization @@allow('create', true) // any user in the organization can read the organization - @@allow('read', members?[auth() == this]) + @@allow('read', members?[auth().id == id]) } abstract model organizationBaseEntity { @@ -632,15 +632,15 @@ abstract model organizationBaseEntity { groups Group[] // when create, owner must be set to current user, and user must be in the organization - @@allow('create', owner == auth() && org.members?[this == auth()]) + @@allow('create', owner == auth() && org.members?[id == auth().id]) // only the owner can update it and is not allowed to change the owner - @@allow('update', owner == auth() && org.members?[this == auth()] && future().owner == owner) + @@allow('update', owner == auth() && org.members?[id == auth().id] && future().owner == owner) // allow owner to read @@allow('read', owner == auth()) // allow shared group members to read it - @@allow('read', groups?[users?[this == auth()]]) + @@allow('read', groups?[users?[id == auth().id]]) // allow organization to access if public - @@allow('read', isPublic && org.members?[this == auth()]) + @@allow('read', isPublic && org.members?[id == auth().id]) // can not be read if deleted @@deny('all', isDeleted == true) } @@ -667,7 +667,7 @@ model Group { orgId String // group is shared by organization - @@allow('all', org.members?[auth() == this]) + @@allow('all', org.members?[auth().id == id]) } ` );