Skip to content

Commit b8a875e

Browse files
authored
fix: fix policy generation for collection predicate expressions (#706)
1 parent 2d41a9f commit b8a875e

File tree

7 files changed

+229
-71
lines changed

7 files changed

+229
-71
lines changed

packages/schema/src/language-server/validator/attribute-application-validator.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import {
1616
isReferenceExpr,
1717
} from '@zenstackhq/language/ast';
1818
import { isFutureExpr, resolved } from '@zenstackhq/sdk';
19-
import { ValidationAcceptor, streamAllContents } from 'langium';
19+
import { ValidationAcceptor, streamAst } from 'langium';
2020
import pluralize from 'pluralize';
2121
import { AstValidator } from '../types';
2222
import { getStringLiteral, mapBuiltinTypeToExpressionType, typeAssignable } from './utils';
@@ -134,7 +134,7 @@ export default class AttributeApplicationValidator implements AstValidator<Attri
134134
this.validatePolicyKinds(kind, ['read', 'update', 'all'], attr, accept);
135135

136136
const expr = attr.args[1].value;
137-
if ([expr, ...streamAllContents(expr)].some((node) => isFutureExpr(node))) {
137+
if (streamAst(expr).some((node) => isFutureExpr(node))) {
138138
accept('error', `"future()" is not allowed in field-level policy rules`, { node: expr });
139139
}
140140
}

packages/schema/src/language-server/validator/expression-validator.ts

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@ import {
22
BinaryExpr,
33
Expression,
44
ExpressionType,
5-
isBinaryExpr,
65
isDataModel,
76
isEnum,
87
isNullExpr,
98
isThisExpr,
109
} from '@zenstackhq/language/ast';
1110
import { isDataModelFieldReference } from '@zenstackhq/sdk';
1211
import { ValidationAcceptor } from 'langium';
13-
import { isAuthInvocation } from '../../utils/ast-utils';
12+
import { isAuthInvocation, isCollectionPredicate } from '../../utils/ast-utils';
1413
import { AstValidator } from '../types';
1514

1615
/**
@@ -23,7 +22,7 @@ export default class ExpressionValidator implements AstValidator<Expression> {
2322
if (isAuthInvocation(expr)) {
2423
// check was done at link time
2524
accept('error', 'auth() cannot be resolved because no "User" model is defined', { node: expr });
26-
} else if (this.isCollectionPredicate(expr)) {
25+
} else if (isCollectionPredicate(expr)) {
2726
accept('error', 'collection predicate can only be used on an array of model type', { node: expr });
2827
} else {
2928
accept('error', 'expression cannot be resolved', {
@@ -142,8 +141,4 @@ export default class ExpressionValidator implements AstValidator<Expression> {
142141
}
143142
}
144143
}
145-
146-
private isCollectionPredicate(expr: Expression) {
147-
return isBinaryExpr(expr) && ['?', '!', '^'].includes(expr.operator);
148-
}
149144
}

packages/schema/src/plugins/access-policy/policy-guard-generator.ts

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import {
55
DataModelFieldAttribute,
66
Enum,
77
Expression,
8-
MemberAccessExpr,
98
Model,
109
isBinaryExpr,
1110
isDataModel,
@@ -49,12 +48,12 @@ import {
4948
resolved,
5049
saveProject,
5150
} from '@zenstackhq/sdk';
52-
import { streamAllContents } from 'langium';
51+
import { streamAllContents, streamAst, streamContents } from 'langium';
5352
import { lowerCaseFirst } from 'lower-case-first';
5453
import path from 'path';
5554
import { FunctionDeclaration, SourceFile, VariableDeclarationKind, WriterFunction } from 'ts-morph';
5655
import { name } from '.';
57-
import { getIdFields, isAuthInvocation } from '../../utils/ast-utils';
56+
import { getIdFields, isAuthInvocation, isCollectionPredicate } from '../../utils/ast-utils';
5857
import {
5958
TypeScriptExpressionTransformer,
6059
TypeScriptExpressionTransformerError,
@@ -237,7 +236,7 @@ export default class PolicyGenerator {
237236
}
238237

239238
private hasFutureReference(expr: Expression) {
240-
for (const node of this.allNodes(expr)) {
239+
for (const node of streamAst(expr)) {
241240
if (isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)) {
242241
return true;
243242
}
@@ -434,7 +433,7 @@ export default class PolicyGenerator {
434433

435434
private canCheckCreateBasedOnInput(model: DataModel, allows: Expression[], denies: Expression[]) {
436435
return [...allows, ...denies].every((rule) => {
437-
return [...this.allNodes(rule)].every((expr) => {
436+
return streamAst(rule).every((expr) => {
438437
if (isThisExpr(expr)) {
439438
return false;
440439
}
@@ -487,6 +486,8 @@ export default class PolicyGenerator {
487486
});
488487
};
489488

489+
// visit a reference or member access expression to build a
490+
// selection path
490491
const visit = (node: Expression): string[] | undefined => {
491492
if (isReferenceExpr(node)) {
492493
const target = resolved(node.target);
@@ -509,35 +510,50 @@ export default class PolicyGenerator {
509510
return undefined;
510511
};
511512

512-
for (const rule of [...allows, ...denies]) {
513-
for (const expr of [...this.allNodes(rule)].filter((node): node is Expression => isExpression(node))) {
514-
if (isThisExpr(expr) && !isMemberAccessExpr(expr.$container)) {
515-
// a standalone `this` expression, include all id fields
516-
const model = expr.$resolvedType?.decl as DataModel;
517-
const idFields = getIdFields(model);
518-
idFields.forEach((field) => addPath([field.name]));
519-
continue;
520-
}
521-
522-
// only care about member access and reference expressions
523-
if (!isMemberAccessExpr(expr) && !isReferenceExpr(expr)) {
524-
continue;
525-
}
526-
527-
if (expr.$container.$type === MemberAccessExpr) {
528-
// only visit top-level member access
529-
continue;
530-
}
513+
// collect selection paths from the given expression
514+
const collectReferencePaths = (expr: Expression): string[][] => {
515+
if (isThisExpr(expr) && !isMemberAccessExpr(expr.$container)) {
516+
// a standalone `this` expression, include all id fields
517+
const model = expr.$resolvedType?.decl as DataModel;
518+
const idFields = getIdFields(model);
519+
return idFields.map((field) => [field.name]);
520+
}
531521

522+
if (isMemberAccessExpr(expr) || isReferenceExpr(expr)) {
532523
const path = visit(expr);
533524
if (path) {
534525
if (isDataModel(expr.$resolvedType?.decl)) {
535-
// member selection ended at a data model field, include its 'id'
536-
path.push('id');
526+
// member selection ended at a data model field, include its id fields
527+
const idFields = getIdFields(expr.$resolvedType?.decl as DataModel);
528+
return idFields.map((field) => [...path, field.name]);
529+
} else {
530+
return [path];
537531
}
538-
addPath(path);
532+
} else {
533+
return [];
539534
}
535+
} else if (isCollectionPredicate(expr)) {
536+
const path = visit(expr.left);
537+
if (path) {
538+
// recurse into RHS
539+
const rhs = collectReferencePaths(expr.right);
540+
// combine path of LHS and RHS
541+
return rhs.map((r) => [...path, ...r]);
542+
} else {
543+
return [];
544+
}
545+
} else {
546+
// recurse
547+
const children = streamContents(expr)
548+
.filter((child): child is Expression => isExpression(child))
549+
.toArray();
550+
return children.flatMap((child) => collectReferencePaths(child));
540551
}
552+
};
553+
554+
for (const rule of [...allows, ...denies]) {
555+
const paths = collectReferencePaths(rule);
556+
paths.forEach((p) => addPath(p));
541557
}
542558

543559
return Object.keys(result).length === 0 ? undefined : result;
@@ -556,7 +572,7 @@ export default class PolicyGenerator {
556572
this.generateNormalizedAuthRef(model, allows, denies, statements);
557573

558574
const hasFieldAccess = [...denies, ...allows].some((rule) =>
559-
[...this.allNodes(rule)].some(
575+
streamAst(rule).some(
560576
(child) =>
561577
// this.???
562578
isThisExpr(child) ||
@@ -724,7 +740,7 @@ export default class PolicyGenerator {
724740
) {
725741
// check if any allow or deny rule contains 'auth()' invocation
726742
const hasAuthRef = [...allows, ...denies].some((rule) =>
727-
[...this.allNodes(rule)].some((child) => isAuthInvocation(child))
743+
streamAst(rule).some((child) => isAuthInvocation(child))
728744
);
729745

730746
if (hasAuthRef) {
@@ -747,9 +763,4 @@ export default class PolicyGenerator {
747763
);
748764
}
749765
}
750-
751-
private *allNodes(expr: Expression) {
752-
yield expr;
753-
yield* streamAllContents(expr);
754-
}
755766
}

packages/schema/src/utils/ast-utils.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import {
2+
BinaryExpr,
23
DataModel,
34
DataModelField,
45
Expression,
56
isArrayExpr,
7+
isBinaryExpr,
68
isDataModel,
79
isDataModelField,
810
isInvocationExpr,
@@ -150,3 +152,7 @@ export function getAllDeclarationsFromImports(documents: LangiumDocuments, model
150152
const imports = resolveTransitiveImports(documents, model);
151153
return model.declarations.concat(...imports.map((imp) => imp.declarations));
152154
}
155+
156+
export function isCollectionPredicate(expr: Expression): expr is BinaryExpr {
157+
return isBinaryExpr(expr) && ['?', '!', '^'].includes(expr.operator);
158+
}

packages/schema/src/utils/typescript-expression-transformer.ts

Lines changed: 53 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import {
1717
UnaryExpr,
1818
} from '@zenstackhq/language/ast';
1919
import { ExpressionContext, getLiteral, isFromStdlib, isFutureExpr } from '@zenstackhq/sdk';
20+
import { match, P } from 'ts-pattern';
2021
import { getIdFields } from './ast-utils';
2122

2223
export class TypeScriptExpressionTransformerError extends Error {
@@ -53,7 +54,7 @@ export class TypeScriptExpressionTransformer {
5354
*
5455
* @param isPostGuard indicates if we're writing for post-update conditions
5556
*/
56-
constructor(private readonly options?: Options) {}
57+
constructor(private readonly options: Options) {}
5758

5859
/**
5960
* Transforms the given expression to a TypeScript expression.
@@ -302,33 +303,57 @@ export class TypeScriptExpressionTransformer {
302303
}
303304

304305
private binary(expr: BinaryExpr, normalizeUndefined: boolean): string {
305-
if (expr.operator === 'in') {
306-
return `(${this.transform(expr.right, false)}?.includes(${this.transform(
307-
expr.left,
308-
normalizeUndefined
309-
)}) ?? false)`;
310-
} else if (
311-
(expr.operator === '==' || expr.operator === '!=') &&
312-
(isThisExpr(expr.left) || isThisExpr(expr.right))
313-
) {
314-
// map equality comparison with `this` to id comparison
315-
const _this = isThisExpr(expr.left) ? expr.left : expr.right;
316-
const model = _this.$resolvedType?.decl as DataModel;
317-
const idFields = getIdFields(model);
318-
if (!idFields || idFields.length === 0) {
319-
throw new TypeScriptExpressionTransformerError(`model "${model.name}" does not have an id field`);
320-
}
321-
let result = `allFieldsEqual(${this.transform(expr.left, false)},
306+
const _default = `(${this.transform(expr.left, normalizeUndefined)} ${expr.operator} ${this.transform(
307+
expr.right,
308+
normalizeUndefined
309+
)})`;
310+
311+
return match(expr.operator)
312+
.with(
313+
'in',
314+
() =>
315+
`(${this.transform(expr.right, false)}?.includes(${this.transform(
316+
expr.left,
317+
normalizeUndefined
318+
)}) ?? false)`
319+
)
320+
.with(P.union('==', '!='), () => {
321+
if (isThisExpr(expr.left) || isThisExpr(expr.right)) {
322+
// map equality comparison with `this` to id comparison
323+
const _this = isThisExpr(expr.left) ? expr.left : expr.right;
324+
const model = _this.$resolvedType?.decl as DataModel;
325+
const idFields = getIdFields(model);
326+
if (!idFields || idFields.length === 0) {
327+
throw new TypeScriptExpressionTransformerError(
328+
`model "${model.name}" does not have an id field`
329+
);
330+
}
331+
let result = `allFieldsEqual(${this.transform(expr.left, false)},
322332
${this.transform(expr.right, false)}, [${idFields.map((f) => "'" + f.name + "'").join(', ')}])`;
323-
if (expr.operator === '!=') {
324-
result = `!${result}`;
325-
}
326-
return result;
327-
} else {
328-
return `(${this.transform(expr.left, normalizeUndefined)} ${expr.operator} ${this.transform(
329-
expr.right,
330-
normalizeUndefined
331-
)})`;
332-
}
333+
if (expr.operator === '!=') {
334+
result = `!${result}`;
335+
}
336+
return result;
337+
} else {
338+
return _default;
339+
}
340+
})
341+
.with(P.union('?', '!', '^'), (op) => this.collectionPredicate(expr, op, normalizeUndefined))
342+
.otherwise(() => _default);
343+
}
344+
345+
private collectionPredicate(expr: BinaryExpr, operator: '?' | '!' | '^', normalizeUndefined: boolean) {
346+
const operand = this.transform(expr.left, normalizeUndefined);
347+
const innerTransformer = new TypeScriptExpressionTransformer({
348+
...this.options,
349+
fieldReferenceContext: '_item',
350+
});
351+
const predicate = innerTransformer.transform(expr.right, normalizeUndefined);
352+
353+
return match(operator)
354+
.with('?', () => `!!((${operand})?.some((_item: any) => ${predicate}))`)
355+
.with('!', () => `!!((${operand})?.every((_item: any) => ${predicate}))`)
356+
.with('^', () => `!((${operand})?.some((_item: any) => ${predicate}))`)
357+
.exhaustive();
333358
}
334359
}

0 commit comments

Comments
 (0)