Skip to content

refactor: simplify zmodel linking by improving scope computation; make AST cloning from base models more robust #957

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions packages/language/src/ast.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { AbstractDeclaration, ExpressionType, BinaryExpr } from './generated/ast';
import { AstNode } from 'langium';
import { AbstractDeclaration, BinaryExpr, DataModel, ExpressionType } from './generated/ast';

export * from './generated/ast';
export { AstNode, Reference } from 'langium';
export * from './generated/ast';

/**
* Shape of type resolution result: an expression type or reference to a declaration
Expand Down Expand Up @@ -44,18 +45,19 @@ declare module './generated/ast' {
$resolvedParam?: AttributeParam;
}

interface DataModel {
/**
* Resolved fields, include inherited fields
*/
$resolvedFields: Array<DataModelField>;
interface DataModelField {
$inheritedFrom?: DataModel;
}

interface DataModelField {
$isInherited?: boolean;
interface DataModelAttribute {
$inheritedFrom?: DataModel;
}
}

export interface InheritableNode extends AstNode {
$inheritedFrom?: DataModel;
}

declare module 'langium' {
export interface AstNode {
/**
Expand Down
2 changes: 1 addition & 1 deletion packages/runtime/src/enhancements/create-enhancement.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ import semver from 'semver';
import { PRISMA_MINIMUM_VERSION } from '../constants';
import { ModelMeta } from '../cross';
import type { AuthUser } from '../types';
import { withDefaultAuth } from './default-auth';
import { withOmit } from './omit';
import { withPassword } from './password';
import { withPolicy } from './policy';
import type { ErrorTransformer } from './proxy';
import type { PolicyDef, ZodSchemas } from './types';
import { withDefaultAuth } from './default-auth';

/**
* Kinds of enhancements to `PrismaClient`
Expand Down
15 changes: 0 additions & 15 deletions packages/runtime/src/enhancements/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,3 @@ export function prismaClientKnownRequestError(prisma: DbClientContract, prismaMo
export function prismaClientUnknownRequestError(prismaModule: any, ...args: unknown[]): Error {
throw new prismaModule.PrismaClientUnknownRequestError(...args);
}

export function deepGet(object: object, path: string | string[] | undefined, defaultValue: unknown): unknown {
if (path === undefined || path === '') {
return defaultValue;
}
const keys = Array.isArray(path) ? path : path.split('.');
for (const key of keys) {
if (object && typeof object === 'object' && key in object) {
object = object[key as keyof typeof object];
} else {
return defaultValue;
}
}
return object !== undefined ? object : defaultValue;
}
2 changes: 1 addition & 1 deletion packages/schema/src/cli/cli-util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ export async function loadDocument(fileName: string): Promise<Model> {

validationAfterMerge(model);

mergeBaseModel(model);
mergeBaseModel(model, services.references.Linker);

return model;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@ import {
isStringLiteral,
ReferenceExpr,
} from '@zenstackhq/language/ast';
import { analyzePolicies, getLiteral, getModelIdFields, getModelUniqueFields } from '@zenstackhq/sdk';
import {
analyzePolicies,
getLiteral,
getModelFieldsWithBases,
getModelIdFields,
getModelUniqueFields,
} from '@zenstackhq/sdk';
import { AstNode, DiagnosticInfo, getDocument, ValidationAcceptor } from 'langium';
import { IssueCodes, SCALAR_TYPES } from '../constants';
import { AstValidator } from '../types';
Expand All @@ -20,16 +26,15 @@ import { validateDuplicatedDeclarations } from './utils';
export default class DataModelValidator implements AstValidator<DataModel> {
validate(dm: DataModel, accept: ValidationAcceptor): void {
this.validateBaseAbstractModel(dm, accept);
validateDuplicatedDeclarations(dm.$resolvedFields, accept);
validateDuplicatedDeclarations(getModelFieldsWithBases(dm), accept);
this.validateAttributes(dm, accept);
this.validateFields(dm, accept);
}

private validateFields(dm: DataModel, accept: ValidationAcceptor) {
const idFields = dm.$resolvedFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id'));
const uniqueFields = dm.$resolvedFields.filter((f) =>
f.attributes.find((attr) => attr.decl.ref?.name === '@unique')
);
const allFields = getModelFieldsWithBases(dm);
const idFields = allFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id'));
const uniqueFields = allFields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@unique'));
const modelLevelIds = getModelIdFields(dm);
const modelUniqueFields = getModelUniqueFields(dm);

Expand All @@ -42,7 +47,7 @@ export default class DataModelValidator implements AstValidator<DataModel> {
const { allows, denies, hasFieldValidation } = analyzePolicies(dm);
if (allows.length > 0 || denies.length > 0 || hasFieldValidation) {
// TODO: relax this requirement to require only @unique fields
// when access policies or field valdaition is used, require an @id field
// when access policies or field validation is used, require an @id field
accept(
'error',
'Model must include a field with @id or @unique attribute, or a model-level @@id or @@unique attribute to use access policies',
Expand Down Expand Up @@ -74,10 +79,10 @@ export default class DataModelValidator implements AstValidator<DataModel> {
dm.fields.forEach((field) => this.validateField(field, accept));

if (!dm.isAbstract) {
dm.$resolvedFields
allFields
.filter((x) => isDataModel(x.type.reference?.ref))
.forEach((y) => {
this.validateRelationField(y, accept);
this.validateRelationField(dm, y, accept);
});
}
}
Expand Down Expand Up @@ -194,7 +199,7 @@ export default class DataModelValidator implements AstValidator<DataModel> {
// points back
const oppositeModel = field.type.reference?.ref as DataModel;
if (oppositeModel) {
const oppositeModelFields = oppositeModel.$resolvedFields as DataModelField[];
const oppositeModelFields = getModelFieldsWithBases(oppositeModel);
for (const oppositeField of oppositeModelFields) {
// find the opposite relation with the matching name
const relAttr = oppositeField.attributes.find((a) => a.decl.ref?.name === '@relation');
Expand All @@ -213,7 +218,7 @@ export default class DataModelValidator implements AstValidator<DataModel> {
return false;
}

private validateRelationField(field: DataModelField, accept: ValidationAcceptor) {
private validateRelationField(contextModel: DataModel, field: DataModelField, accept: ValidationAcceptor) {
const thisRelation = this.parseRelation(field, accept);
if (!thisRelation.valid) {
return;
Expand All @@ -223,22 +228,22 @@ export default class DataModelValidator implements AstValidator<DataModel> {
const oppositeModel = field.type.reference!.ref! as DataModel;

// Use name because the current document might be updated
let oppositeFields = oppositeModel.$resolvedFields.filter(
(f) => f.type.reference?.ref?.name === field.$container.name
let oppositeFields = getModelFieldsWithBases(oppositeModel).filter(
(f) => f.type.reference?.ref?.name === contextModel.name
);
oppositeFields = oppositeFields.filter((f) => {
const fieldRel = this.parseRelation(f);
return fieldRel.valid && fieldRel.name === thisRelation.name;
});

if (oppositeFields.length === 0) {
const node = field.$isInherited ? field.$container : field;
const info: DiagnosticInfo<AstNode, string> = { node, code: IssueCodes.MissingOppositeRelation };
const info: DiagnosticInfo<AstNode, string> = {
node: field,
code: IssueCodes.MissingOppositeRelation,
};

info.property = 'name';
// use cstNode because the field might be inherited from parent model
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
const container = field.$cstNode!.element.$container as DataModel;
const container = field.$container;

const relationFieldDocUri = getDocument(container).textDocument.uri;
const relationDataModelName = container.name;
Expand All @@ -247,20 +252,20 @@ export default class DataModelValidator implements AstValidator<DataModel> {
relationFieldName: field.name,
relationDataModelName,
relationFieldDocUri,
dataModelName: field.$container.name,
dataModelName: contextModel.name,
};

info.data = data;

accept(
'error',
`The relation field "${field.name}" on model "${field.$container.name}" is missing an opposite relation field on model "${oppositeModel.name}"`,
`The relation field "${field.name}" on model "${contextModel.name}" is missing an opposite relation field on model "${oppositeModel.name}"`,
info
);
return;
} else if (oppositeFields.length > 1) {
oppositeFields
.filter((x) => !x.$isInherited)
.filter((x) => !x.$inheritedFrom)
.forEach((f) => {
if (this.isSelfRelation(f)) {
// self relations are partial
Expand Down
4 changes: 2 additions & 2 deletions packages/schema/src/language-server/validator/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ export function validateDuplicatedDeclarations(
for (const [name, decls] of Object.entries<AstNode[]>(groupByName)) {
if (decls.length > 1) {
let errorField = decls[1];
if (decls[0].$type === 'DataModelField') {
const nonInheritedFields = decls.filter((x) => !(x as DataModelField).$isInherited);
if (isDataModelField(decls[0])) {
const nonInheritedFields = decls.filter((x) => !(x as DataModelField).$inheritedFrom);
if (nonInheritedFields.length > 0) {
errorField = nonInheritedFields.slice(-1)[0];
}
Expand Down
11 changes: 6 additions & 5 deletions packages/schema/src/language-server/zmodel-code-action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@ import { DataModel, DataModelField, Model, isDataModel } from '@zenstackhq/langu
import {
AstReflection,
CodeActionProvider,
getDocument,
IndexManager,
LangiumDocument,
LangiumDocuments,
LangiumServices,
MaybePromise,
getDocument,
} from 'langium';

import { getModelFieldsWithBases } from '@zenstackhq/sdk';
import { CodeAction, CodeActionKind, CodeActionParams, Command, Diagnostic } from 'vscode-languageserver';
import { IssueCodes } from './constants';
import { ZModelFormatter } from './zmodel-formatter';
import { MissingOppositeRelationData } from './validator/datamodel-validator';
import { ZModelFormatter } from './zmodel-formatter';

export class ZModelCodeActionProvider implements CodeActionProvider {
protected readonly reflection: AstReflection;
Expand Down Expand Up @@ -92,8 +93,8 @@ export class ZModelCodeActionProvider implements CodeActionProvider {

let newText = '';
if (fieldAstNode.type.array) {
//post Post[]
const idField = container.$resolvedFields.find((f) =>
// post Post[]
const idField = getModelFieldsWithBases(container).find((f) =>
f.attributes.find((attr) => attr.decl.ref?.name === '@id')
) as DataModelField;

Expand All @@ -111,7 +112,7 @@ export class ZModelCodeActionProvider implements CodeActionProvider {
const idFieldName = idField.name;
const referenceIdFieldName = fieldName + this.upperCaseFirstLetter(idFieldName);

if (!oppositeModel.$resolvedFields.find((f) => f.name === referenceIdFieldName)) {
if (!getModelFieldsWithBases(oppositeModel).find((f) => f.name === referenceIdFieldName)) {
referenceField = '\n' + indent + `${referenceIdFieldName} ${idField.type.type}`;
}

Expand Down
73 changes: 15 additions & 58 deletions packages/schema/src/language-server/zmodel-linker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ import {
isReferenceExpr,
isStringLiteral,
} from '@zenstackhq/language/ast';
import { getContainingModel, hasAttribute, isAuthInvocation, isFutureExpr } from '@zenstackhq/sdk';
import {
getContainingModel,
getModelFieldsWithBases,
hasAttribute,
isAuthInvocation,
isFutureExpr,
} from '@zenstackhq/sdk';
import {
AstNode,
AstNodeDescription,
Expand All @@ -52,7 +58,7 @@ import {
} from 'langium';
import { match } from 'ts-pattern';
import { CancellationToken } from 'vscode-jsonrpc';
import { getAllDeclarationsFromImports, getContainingDataModel, isCollectionPredicate } from '../utils/ast-utils';
import { getAllDeclarationsFromImports, getContainingDataModel } from '../utils/ast-utils';
import { mapBuiltinTypeToExpressionType } from './validator/utils';

interface DefaultReference extends Reference {
Expand Down Expand Up @@ -256,26 +262,9 @@ export class ZModelLinker extends DefaultLinker {
}

private resolveReference(node: ReferenceExpr, document: LangiumDocument<AstNode>, extraScopes: ScopeProvider[]) {
this.linkReference(node, 'target', document, extraScopes);
node.args.forEach((arg) => this.resolve(arg, document, extraScopes));
this.resolveDefault(node, document, extraScopes);

if (node.target.ref) {
// if the reference is inside the RHS of a collection predicate, it cannot be resolve to a field
// not belonging to the collection's model type

const collectionPredicateContext = this.getCollectionPredicateContextDataModel(node);
if (
// inside a collection predicate RHS
collectionPredicateContext &&
// current ref expr is resolved to a field
isDataModelField(node.target.ref) &&
// the resolved field doesn't belong to the collection predicate's operand's type
node.target.ref.$container !== collectionPredicateContext
) {
this.unresolvableRefExpr(node);
return;
}

// resolve type
if (node.target.ref.$type === EnumField) {
this.resolveToBuiltinTypeOrDecl(node, node.target.ref.$container);
Expand All @@ -285,26 +274,6 @@ export class ZModelLinker extends DefaultLinker {
}
}

private getCollectionPredicateContextDataModel(node: ReferenceExpr) {
let curr: AstNode | undefined = node;
while (curr) {
if (
curr.$container &&
// parent is a collection predicate
isCollectionPredicate(curr.$container) &&
// the collection predicate's LHS is resolved to a DataModel
isDataModel(curr.$container.left.$resolvedType?.decl) &&
// current node is the RHS
curr.$containerProperty === 'right'
) {
// return the resolved type of LHS
return curr.$container.left.$resolvedType?.decl;
}
curr = curr.$container;
}
return undefined;
}

private resolveArray(node: ArrayExpr, document: LangiumDocument<AstNode>, extraScopes: ScopeProvider[]) {
node.items.forEach((item) => this.resolve(item, document, extraScopes));

Expand Down Expand Up @@ -367,14 +336,11 @@ export class ZModelLinker extends DefaultLinker {
document: LangiumDocument<AstNode>,
extraScopes: ScopeProvider[]
) {
this.resolve(node.operand, document, extraScopes);
this.resolveDefault(node, document, extraScopes);
const operandResolved = node.operand.$resolvedType;

if (operandResolved && !operandResolved.array && isDataModel(operandResolved.decl)) {
const modelDecl = operandResolved.decl as DataModel;
const provider = (name: string) => modelDecl.$resolvedFields.find((f) => f.name === name);
// member access is resolved only in the context of the operand type
this.linkReference(node, 'member', document, [provider], true);
if (node.member.ref) {
this.resolveToDeclaredType(node, node.member.ref.type);

Expand All @@ -388,20 +354,10 @@ export class ZModelLinker extends DefaultLinker {
}

private resolveCollectionPredicate(node: BinaryExpr, document: LangiumDocument, extraScopes: ScopeProvider[]) {
this.resolve(node.left, document, extraScopes);
this.resolveDefault(node, document, extraScopes);

const resolvedType = node.left.$resolvedType;
if (resolvedType && isDataModel(resolvedType.decl) && resolvedType.array) {
const dataModelDecl = resolvedType.decl;
const provider = (name: string) => {
if (name === 'this') {
return dataModelDecl;
} else {
return dataModelDecl.$resolvedFields.find((f) => f.name === name);
}
};
extraScopes = [provider, ...extraScopes];
this.resolve(node.right, document, extraScopes);
this.resolveToBuiltinTypeOrDecl(node, 'Boolean');
} else {
// error is reported in validation pass
Expand Down Expand Up @@ -455,10 +411,11 @@ export class ZModelLinker extends DefaultLinker {
//
// In model B, the attribute argument "myId" is resolved to the field "myId" in model A

const transtiveDataModel = attrAppliedOn.type.reference?.ref as DataModel;
if (transtiveDataModel) {
const transitiveDataModel = attrAppliedOn.type.reference?.ref as DataModel;
if (transitiveDataModel) {
// resolve references in the context of the transitive data model
const scopeProvider = (name: string) => transtiveDataModel.$resolvedFields.find((f) => f.name === name);
const scopeProvider = (name: string) =>
getModelFieldsWithBases(transitiveDataModel).find((f) => f.name === name);
if (isArrayExpr(node.value)) {
node.value.items.forEach((item) => {
if (isReferenceExpr(item)) {
Expand Down
Loading