Skip to content

Commit dc106a9

Browse files
authored
fix: policy generation error when field-level rules contain "this" expression (#670)
1 parent 322eae8 commit dc106a9

File tree

15 files changed

+376
-67
lines changed

15 files changed

+376
-67
lines changed

packages/runtime/src/validation.ts

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,15 @@ export function hasAllFields(obj: any, fields: string[]) {
3232
}
3333
return fields.every((f) => obj[f] !== undefined && obj[f] !== null);
3434
}
35+
36+
/**
37+
* Check if the given objects have equal values for the given fields. Returns
38+
* false if either object is nullish or is not an object.
39+
*/
40+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
41+
export function allFieldsEqual(obj1: any, obj2: any, fields: string[]) {
42+
if (!obj1 || !obj2 || typeof obj1 !== 'object' || typeof obj2 !== 'object') {
43+
return false;
44+
}
45+
return fields.every((f) => obj1[f] === obj2[f]);
46+
}

packages/schema/src/language-server/validator/expression-validator.ts

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,14 @@
1-
import { BinaryExpr, Expression, ExpressionType, isBinaryExpr, isEnum } from '@zenstackhq/language/ast';
1+
import {
2+
BinaryExpr,
3+
Expression,
4+
ExpressionType,
5+
isBinaryExpr,
6+
isDataModel,
7+
isEnum,
8+
isNullExpr,
9+
isThisExpr,
10+
} from '@zenstackhq/language/ast';
11+
import { isDataModelFieldReference } from '@zenstackhq/sdk';
212
import { ValidationAcceptor } from 'langium';
313
import { isAuthInvocation } from '../../utils/ast-utils';
414
import { AstValidator } from '../types';
@@ -93,6 +103,43 @@ export default class ExpressionValidator implements AstValidator<Expression> {
93103

94104
break;
95105
}
106+
107+
case '==':
108+
case '!=': {
109+
// disallow comparing model type with scalar type or comparison between
110+
// incompatible model types
111+
const leftType = expr.left.$resolvedType?.decl;
112+
const rightType = expr.right.$resolvedType?.decl;
113+
if (isDataModel(leftType) && isDataModel(rightType)) {
114+
if (leftType != rightType) {
115+
// incompatible model types
116+
// TODO: inheritance case?
117+
accept('error', 'incompatible operand types', { node: expr });
118+
}
119+
120+
// not supported:
121+
// - foo == bar
122+
// - foo == this
123+
if (
124+
isDataModelFieldReference(expr.left) &&
125+
(isThisExpr(expr.right) || isDataModelFieldReference(expr.right))
126+
) {
127+
accept('error', 'comparison between model-typed fields are not supported', { node: expr });
128+
} else if (
129+
isDataModelFieldReference(expr.right) &&
130+
(isThisExpr(expr.left) || isDataModelFieldReference(expr.left))
131+
) {
132+
accept('error', 'comparison between model-typed fields are not supported', { node: expr });
133+
}
134+
} else if (
135+
(isDataModel(leftType) && !isNullExpr(expr.right)) ||
136+
(isDataModel(rightType) && !isNullExpr(expr.left))
137+
) {
138+
// comparing model against scalar (except null)
139+
accept('error', 'incompatible operand types', { node: expr });
140+
}
141+
break;
142+
}
96143
}
97144
}
98145

packages/schema/src/language-server/zmodel-linker.ts

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ interface DefaultReference extends Reference {
6161
_nodeDescription?: AstNodeDescription;
6262
}
6363

64-
type ScopeProvider = (name: string) => ReferenceTarget | undefined;
64+
type ScopeProvider = (name: string) => ReferenceTarget | DataModel | undefined;
6565

6666
/**
6767
* Langium linker implementation which links references and resolves expression types
@@ -342,7 +342,13 @@ export class ZModelLinker extends DefaultLinker {
342342
const resolvedType = node.left.$resolvedType;
343343
if (resolvedType && isDataModel(resolvedType.decl) && resolvedType.array) {
344344
const dataModelDecl = resolvedType.decl;
345-
const provider = (name: string) => dataModelDecl.$resolvedFields.find((f) => f.name === name);
345+
const provider = (name: string) => {
346+
if (name === 'this') {
347+
return dataModelDecl;
348+
} else {
349+
return dataModelDecl.$resolvedFields.find((f) => f.name === name);
350+
}
351+
};
346352
extraScopes = [provider, ...extraScopes];
347353
this.resolve(node.right, document, extraScopes);
348354
this.resolveToBuiltinTypeOrDecl(node, 'Boolean');
@@ -351,13 +357,16 @@ export class ZModelLinker extends DefaultLinker {
351357
}
352358
}
353359

354-
private resolveThis(
355-
node: ThisExpr,
356-
// eslint-disable-next-line @typescript-eslint/no-unused-vars
357-
document: LangiumDocument<AstNode>,
358-
// eslint-disable-next-line @typescript-eslint/no-unused-vars
359-
extraScopes: ScopeProvider[]
360-
) {
360+
private resolveThis(node: ThisExpr, _document: LangiumDocument<AstNode>, extraScopes: ScopeProvider[]) {
361+
// resolve from scopes first
362+
for (const scope of extraScopes) {
363+
const r = scope('this');
364+
if (isDataModel(r)) {
365+
this.resolveToBuiltinTypeOrDecl(node, r);
366+
return;
367+
}
368+
}
369+
361370
let decl: AstNode | undefined = node.$container;
362371

363372
while (decl && !isDataModel(decl)) {
@@ -369,13 +378,7 @@ export class ZModelLinker extends DefaultLinker {
369378
}
370379
}
371380

372-
private resolveNull(
373-
node: NullExpr,
374-
// eslint-disable-next-line @typescript-eslint/no-unused-vars
375-
document: LangiumDocument<AstNode>,
376-
// eslint-disable-next-line @typescript-eslint/no-unused-vars
377-
extraScopes: ScopeProvider[]
378-
) {
381+
private resolveNull(node: NullExpr, _document: LangiumDocument<AstNode>, _extraScopes: ScopeProvider[]) {
379382
// TODO: how to really resolve null?
380383
this.resolveToBuiltinTypeOrDecl(node, 'Null');
381384
}

packages/schema/src/plugins/access-policy/expression-writer.ts

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -279,18 +279,6 @@ export class ExpressionWriter {
279279
const leftIsFieldAccess = this.isFieldAccess(expr.left);
280280
const rightIsFieldAccess = this.isFieldAccess(expr.right);
281281

282-
if (leftIsFieldAccess && rightIsFieldAccess) {
283-
if (
284-
isDataModelFieldReference(expr.left) &&
285-
isDataModelFieldReference(expr.right) &&
286-
expr.left.target.ref?.$container === expr.right.target.ref?.$container
287-
) {
288-
// comparing fields from the same model
289-
} else {
290-
throw new PluginError(name, `Comparing fields from different models is not supported`);
291-
}
292-
}
293-
294282
if (!leftIsFieldAccess && !rightIsFieldAccess) {
295283
// compile down to a plain expression
296284
this.guard(() => {
@@ -318,7 +306,8 @@ export class ExpressionWriter {
318306
$container: fieldAccess.$container,
319307
target: fieldAccess.member,
320308
$resolvedType: fieldAccess.$resolvedType,
321-
} as ReferenceExpr;
309+
$future: true,
310+
} as unknown as ReferenceExpr;
322311
}
323312

324313
// guard member access of `auth()` with null check
@@ -349,10 +338,7 @@ export class ExpressionWriter {
349338
// right now this branch only serves comparison with `auth`, like
350339
// @@allow('all', owner == auth())
351340

352-
const idFields = getIdFields(dataModel);
353-
if (!idFields || idFields.length === 0) {
354-
throw new PluginError(name, `Data model ${dataModel.name} does not have an id field`);
355-
}
341+
const idFields = this.requireIdFields(dataModel);
356342

357343
if (operator !== '==' && operator !== '!=') {
358344
throw new PluginError(name, 'Only == and != operators are allowed');
@@ -389,15 +375,21 @@ export class ExpressionWriter {
389375
});
390376
});
391377
} else {
392-
this.writeOperator(operator, fieldAccess, () => {
393-
if (isDataModelFieldReference(operand) && !this.isPostGuard) {
394-
// if operand is a field reference and we're not generating for post-update guard,
395-
// we should generate a field reference (comparing fields in the same model)
396-
this.writeFieldReference(operand);
397-
} else {
398-
this.plain(operand);
399-
}
400-
});
378+
if (this.equivalentRefs(fieldAccess, operand)) {
379+
// f == f or f != f
380+
// this == this or this != this
381+
this.writer.write(operator === '!=' ? TRUE : FALSE);
382+
} else {
383+
this.writeOperator(operator, fieldAccess, () => {
384+
if (isDataModelFieldReference(operand) && !this.isPostGuard) {
385+
// if operand is a field reference and we're not generating for post-update guard,
386+
// we should generate a field reference (comparing fields in the same model)
387+
this.writeFieldReference(operand);
388+
} else {
389+
this.plain(operand);
390+
}
391+
});
392+
}
401393
}
402394
}, !isThisExpr(fieldAccess));
403395
});
@@ -408,6 +400,32 @@ export class ExpressionWriter {
408400
);
409401
}
410402

403+
private requireIdFields(dataModel: DataModel) {
404+
const idFields = getIdFields(dataModel);
405+
if (!idFields || idFields.length === 0) {
406+
throw new PluginError(name, `Data model ${dataModel.name} does not have an id field`);
407+
}
408+
return idFields;
409+
}
410+
411+
private equivalentRefs(expr1: Expression, expr2: Expression) {
412+
if (isThisExpr(expr1) && isThisExpr(expr2)) {
413+
return true;
414+
}
415+
416+
if (
417+
isReferenceExpr(expr1) &&
418+
isReferenceExpr(expr2) &&
419+
expr1.target.ref === expr2.target.ref &&
420+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
421+
(expr1 as any).$future === (expr2 as any).$future // either both future or both not
422+
) {
423+
return true;
424+
}
425+
426+
return false;
427+
}
428+
411429
// https://www.prisma.io/docs/reference/api-reference/prisma-client-reference#compare-columns-in-the-same-table
412430
private writeFieldReference(expr: ReferenceExpr) {
413431
if (!expr.target.ref) {

packages/schema/src/plugins/access-policy/policy-guard-generator.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ export default class PolicyGenerator {
7979
{ name: 'type QueryContext' },
8080
{ name: 'type DbOperations' },
8181
{ name: 'hasAllFields' },
82+
{ name: 'allFieldsEqual' },
8283
{ name: 'type PolicyDef' },
8384
],
8485
moduleSpecifier: `${RUNTIME_PACKAGE}`,
@@ -486,6 +487,14 @@ export default class PolicyGenerator {
486487

487488
for (const rule of [...allows, ...denies]) {
488489
for (const expr of [...this.allNodes(rule)].filter((node): node is Expression => isExpression(node))) {
490+
if (isThisExpr(expr) && !isMemberAccessExpr(expr.$container)) {
491+
// a standalone `this` expression, include all id fields
492+
const model = expr.$resolvedType?.decl as DataModel;
493+
const idFields = getIdFields(model);
494+
idFields.forEach((field) => addPath([field.name]));
495+
continue;
496+
}
497+
489498
// only care about member access and reference expressions
490499
if (!isMemberAccessExpr(expr) && !isReferenceExpr(expr)) {
491500
continue;

packages/schema/src/utils/typescript-expression-transformer.ts

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import {
22
ArrayExpr,
33
BinaryExpr,
44
BooleanLiteral,
5+
DataModel,
56
Expression,
67
InvocationExpr,
78
isEnumField,
@@ -16,6 +17,7 @@ import {
1617
UnaryExpr,
1718
} from '@zenstackhq/language/ast';
1819
import { ExpressionContext, getLiteral, isFromStdlib, isFutureExpr } from '@zenstackhq/sdk';
20+
import { getIdFields } from './ast-utils';
1921

2022
export class TypeScriptExpressionTransformerError extends Error {
2123
constructor(message: string) {
@@ -94,10 +96,9 @@ export class TypeScriptExpressionTransformer {
9496
}
9597
}
9698

97-
// eslint-disable-next-line @typescript-eslint/no-unused-vars
98-
private this(expr: ThisExpr) {
99-
// "this" is mapped to id comparison
100-
return 'id';
99+
private this(_expr: ThisExpr) {
100+
// "this" is mapped to the input argument
101+
return 'input';
101102
}
102103

103104
private memberAccess(expr: MemberAccessExpr, normalizeUndefined: boolean) {
@@ -306,6 +307,23 @@ export class TypeScriptExpressionTransformer {
306307
expr.left,
307308
normalizeUndefined
308309
)}) ?? false)`;
310+
} else if (
311+
(expr.operator === '==' || expr.operator === '!=') &&
312+
(isThisExpr(expr.left) || isThisExpr(expr.right))
313+
) {
314+
// map equality comparison with `this` to id comparison
315+
const _this = isThisExpr(expr.left) ? expr.left : expr.right;
316+
const model = _this.$resolvedType?.decl as DataModel;
317+
const idFields = getIdFields(model);
318+
if (!idFields || idFields.length === 0) {
319+
throw new TypeScriptExpressionTransformerError(`model "${model.name}" does not have an id field`);
320+
}
321+
let result = `allFieldsEqual(${this.transform(expr.left, false)},
322+
${this.transform(expr.right, false)}, [${idFields.map((f) => "'" + f.name + "'").join(', ')}])`;
323+
if (expr.operator === '!=') {
324+
result = `!${result}`;
325+
}
326+
return result;
309327
} else {
310328
return `(${this.transform(expr.left, normalizeUndefined)} ${expr.operator} ${this.transform(
311329
expr.right,

packages/schema/tests/generator/expression-writer.test.ts

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -113,26 +113,13 @@ describe('Expression Writer Tests', () => {
113113
it('this reference', async () => {
114114
await check(
115115
`
116-
model User { id String @id }
117116
model Test {
118117
id String @id
119-
@@allow('all', auth() == this)
118+
@@allow('all', this == this)
120119
}
121120
`,
122121
(model) => model.attributes[0].args[1].value,
123-
`(user == null) ? { OR: [] } : { id: user.id }`
124-
);
125-
126-
await check(
127-
`
128-
model User { id String @id }
129-
model Test {
130-
id String @id
131-
@@deny('all', this != auth())
132-
}
133-
`,
134-
(model) => model.attributes[0].args[1].value,
135-
`(user == null) ? { AND: [] } : { NOT: { id: user.id } }`
122+
`{OR:[]}`
136123
);
137124

138125
await check(

packages/schema/tests/generator/prisma-generator.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ describe('Prisma generator test', () => {
330330
const post = dmmf.datamodel.models.find((m) => m.name === 'Post');
331331

332332
expect(post?.documentation?.replace(/\s/g, '')).toBe(
333-
`@@allow('read', owner == auth()) @@allow('delete', ownerId == auth())`.replace(/\s/g, '')
333+
`@@allow('read', owner == auth()) @@allow('delete', owner == auth())`.replace(/\s/g, '')
334334
);
335335

336336
const todo = dmmf.datamodel.models.find((m) => m.name === 'Todo');

packages/schema/tests/generator/zmodel/schema.zmodel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ model Post extends Basic {
99
title String
1010
content String?
1111

12-
@@allow('delete', ownerId == auth())
12+
@@allow('delete', owner == auth())
1313
}
1414

1515
model Todo extends Basic {

0 commit comments

Comments
 (0)