Skip to content

fix: policy generation error when field-level rules contain "this" expression #670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions packages/runtime/src/validation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,15 @@ export function hasAllFields(obj: any, fields: string[]) {
}
return fields.every((f) => obj[f] !== undefined && obj[f] !== null);
}

/**
* Check if the given objects have equal values for the given fields. Returns
* false if either object is nullish or is not an object.
*/
// eslint-disable-next-line @typescript-eslint/no-explicit-any
export function allFieldsEqual(obj1: any, obj2: any, fields: string[]) {
if (!obj1 || !obj2 || typeof obj1 !== 'object' || typeof obj2 !== 'object') {
return false;
}
return fields.every((f) => obj1[f] === obj2[f]);
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
import { BinaryExpr, Expression, ExpressionType, isBinaryExpr, isEnum } from '@zenstackhq/language/ast';
import {
BinaryExpr,
Expression,
ExpressionType,
isBinaryExpr,
isDataModel,
isEnum,
isNullExpr,
isThisExpr,
} from '@zenstackhq/language/ast';
import { isDataModelFieldReference } from '@zenstackhq/sdk';
import { ValidationAcceptor } from 'langium';
import { isAuthInvocation } from '../../utils/ast-utils';
import { AstValidator } from '../types';
Expand Down Expand Up @@ -93,6 +103,43 @@ export default class ExpressionValidator implements AstValidator<Expression> {

break;
}

case '==':
case '!=': {
// disallow comparing model type with scalar type or comparison between
// incompatible model types
const leftType = expr.left.$resolvedType?.decl;
const rightType = expr.right.$resolvedType?.decl;
if (isDataModel(leftType) && isDataModel(rightType)) {
if (leftType != rightType) {
// incompatible model types
// TODO: inheritance case?
accept('error', 'incompatible operand types', { node: expr });
}

// not supported:
// - foo == bar
// - foo == this
if (
isDataModelFieldReference(expr.left) &&
(isThisExpr(expr.right) || isDataModelFieldReference(expr.right))
) {
accept('error', 'comparison between model-typed fields are not supported', { node: expr });
} else if (
isDataModelFieldReference(expr.right) &&
(isThisExpr(expr.left) || isDataModelFieldReference(expr.left))
) {
accept('error', 'comparison between model-typed fields are not supported', { node: expr });
}
} else if (
(isDataModel(leftType) && !isNullExpr(expr.right)) ||
(isDataModel(rightType) && !isNullExpr(expr.left))
) {
// comparing model against scalar (except null)
accept('error', 'incompatible operand types', { node: expr });
}
break;
}
}
}

Expand Down
35 changes: 19 additions & 16 deletions packages/schema/src/language-server/zmodel-linker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ interface DefaultReference extends Reference {
_nodeDescription?: AstNodeDescription;
}

type ScopeProvider = (name: string) => ReferenceTarget | undefined;
type ScopeProvider = (name: string) => ReferenceTarget | DataModel | undefined;

/**
* Langium linker implementation which links references and resolves expression types
Expand Down Expand Up @@ -342,7 +342,13 @@ export class ZModelLinker extends DefaultLinker {
const resolvedType = node.left.$resolvedType;
if (resolvedType && isDataModel(resolvedType.decl) && resolvedType.array) {
const dataModelDecl = resolvedType.decl;
const provider = (name: string) => dataModelDecl.$resolvedFields.find((f) => f.name === name);
const provider = (name: string) => {
if (name === 'this') {
return dataModelDecl;
} else {
return dataModelDecl.$resolvedFields.find((f) => f.name === name);
}
};
extraScopes = [provider, ...extraScopes];
this.resolve(node.right, document, extraScopes);
this.resolveToBuiltinTypeOrDecl(node, 'Boolean');
Expand All @@ -351,13 +357,16 @@ export class ZModelLinker extends DefaultLinker {
}
}

private resolveThis(
node: ThisExpr,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
document: LangiumDocument<AstNode>,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
extraScopes: ScopeProvider[]
) {
private resolveThis(node: ThisExpr, _document: LangiumDocument<AstNode>, extraScopes: ScopeProvider[]) {
// resolve from scopes first
for (const scope of extraScopes) {
const r = scope('this');
if (isDataModel(r)) {
this.resolveToBuiltinTypeOrDecl(node, r);
return;
}
}

let decl: AstNode | undefined = node.$container;

while (decl && !isDataModel(decl)) {
Expand All @@ -369,13 +378,7 @@ export class ZModelLinker extends DefaultLinker {
}
}

private resolveNull(
node: NullExpr,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
document: LangiumDocument<AstNode>,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
extraScopes: ScopeProvider[]
) {
private resolveNull(node: NullExpr, _document: LangiumDocument<AstNode>, _extraScopes: ScopeProvider[]) {
// TODO: how to really resolve null?
this.resolveToBuiltinTypeOrDecl(node, 'Null');
}
Expand Down
70 changes: 44 additions & 26 deletions packages/schema/src/plugins/access-policy/expression-writer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -279,18 +279,6 @@ export class ExpressionWriter {
const leftIsFieldAccess = this.isFieldAccess(expr.left);
const rightIsFieldAccess = this.isFieldAccess(expr.right);

if (leftIsFieldAccess && rightIsFieldAccess) {
if (
isDataModelFieldReference(expr.left) &&
isDataModelFieldReference(expr.right) &&
expr.left.target.ref?.$container === expr.right.target.ref?.$container
) {
// comparing fields from the same model
} else {
throw new PluginError(name, `Comparing fields from different models is not supported`);
}
}

if (!leftIsFieldAccess && !rightIsFieldAccess) {
// compile down to a plain expression
this.guard(() => {
Expand Down Expand Up @@ -318,7 +306,8 @@ export class ExpressionWriter {
$container: fieldAccess.$container,
target: fieldAccess.member,
$resolvedType: fieldAccess.$resolvedType,
} as ReferenceExpr;
$future: true,
} as unknown as ReferenceExpr;
}

// guard member access of `auth()` with null check
Expand Down Expand Up @@ -349,10 +338,7 @@ export class ExpressionWriter {
// right now this branch only serves comparison with `auth`, like
// @@allow('all', owner == auth())

const idFields = getIdFields(dataModel);
if (!idFields || idFields.length === 0) {
throw new PluginError(name, `Data model ${dataModel.name} does not have an id field`);
}
const idFields = this.requireIdFields(dataModel);

if (operator !== '==' && operator !== '!=') {
throw new PluginError(name, 'Only == and != operators are allowed');
Expand Down Expand Up @@ -389,15 +375,21 @@ export class ExpressionWriter {
});
});
} else {
this.writeOperator(operator, fieldAccess, () => {
if (isDataModelFieldReference(operand) && !this.isPostGuard) {
// if operand is a field reference and we're not generating for post-update guard,
// we should generate a field reference (comparing fields in the same model)
this.writeFieldReference(operand);
} else {
this.plain(operand);
}
});
if (this.equivalentRefs(fieldAccess, operand)) {
// f == f or f != f
// this == this or this != this
this.writer.write(operator === '!=' ? TRUE : FALSE);
} else {
this.writeOperator(operator, fieldAccess, () => {
if (isDataModelFieldReference(operand) && !this.isPostGuard) {
// if operand is a field reference and we're not generating for post-update guard,
// we should generate a field reference (comparing fields in the same model)
this.writeFieldReference(operand);
} else {
this.plain(operand);
}
});
}
}
}, !isThisExpr(fieldAccess));
});
Expand All @@ -408,6 +400,32 @@ export class ExpressionWriter {
);
}

private requireIdFields(dataModel: DataModel) {
const idFields = getIdFields(dataModel);
if (!idFields || idFields.length === 0) {
throw new PluginError(name, `Data model ${dataModel.name} does not have an id field`);
}
return idFields;
}

private equivalentRefs(expr1: Expression, expr2: Expression) {
if (isThisExpr(expr1) && isThisExpr(expr2)) {
return true;
}

if (
isReferenceExpr(expr1) &&
isReferenceExpr(expr2) &&
expr1.target.ref === expr2.target.ref &&
// eslint-disable-next-line @typescript-eslint/no-explicit-any
(expr1 as any).$future === (expr2 as any).$future // either both future or both not
) {
return true;
}

return false;
}

// https://www.prisma.io/docs/reference/api-reference/prisma-client-reference#compare-columns-in-the-same-table
private writeFieldReference(expr: ReferenceExpr) {
if (!expr.target.ref) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ export default class PolicyGenerator {
{ name: 'type QueryContext' },
{ name: 'type DbOperations' },
{ name: 'hasAllFields' },
{ name: 'allFieldsEqual' },
{ name: 'type PolicyDef' },
],
moduleSpecifier: `${RUNTIME_PACKAGE}`,
Expand Down Expand Up @@ -486,6 +487,14 @@ export default class PolicyGenerator {

for (const rule of [...allows, ...denies]) {
for (const expr of [...this.allNodes(rule)].filter((node): node is Expression => isExpression(node))) {
if (isThisExpr(expr) && !isMemberAccessExpr(expr.$container)) {
// a standalone `this` expression, include all id fields
const model = expr.$resolvedType?.decl as DataModel;
const idFields = getIdFields(model);
idFields.forEach((field) => addPath([field.name]));
continue;
}

// only care about member access and reference expressions
if (!isMemberAccessExpr(expr) && !isReferenceExpr(expr)) {
continue;
Expand Down
26 changes: 22 additions & 4 deletions packages/schema/src/utils/typescript-expression-transformer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import {
ArrayExpr,
BinaryExpr,
BooleanLiteral,
DataModel,
Expression,
InvocationExpr,
isEnumField,
Expand All @@ -16,6 +17,7 @@ import {
UnaryExpr,
} from '@zenstackhq/language/ast';
import { ExpressionContext, getLiteral, isFromStdlib, isFutureExpr } from '@zenstackhq/sdk';
import { getIdFields } from './ast-utils';

export class TypeScriptExpressionTransformerError extends Error {
constructor(message: string) {
Expand Down Expand Up @@ -94,10 +96,9 @@ export class TypeScriptExpressionTransformer {
}
}

// eslint-disable-next-line @typescript-eslint/no-unused-vars
private this(expr: ThisExpr) {
// "this" is mapped to id comparison
return 'id';
private this(_expr: ThisExpr) {
// "this" is mapped to the input argument
return 'input';
}

private memberAccess(expr: MemberAccessExpr, normalizeUndefined: boolean) {
Expand Down Expand Up @@ -306,6 +307,23 @@ export class TypeScriptExpressionTransformer {
expr.left,
normalizeUndefined
)}) ?? false)`;
} else if (
(expr.operator === '==' || expr.operator === '!=') &&
(isThisExpr(expr.left) || isThisExpr(expr.right))
) {
// map equality comparison with `this` to id comparison
const _this = isThisExpr(expr.left) ? expr.left : expr.right;
const model = _this.$resolvedType?.decl as DataModel;
const idFields = getIdFields(model);
if (!idFields || idFields.length === 0) {
throw new TypeScriptExpressionTransformerError(`model "${model.name}" does not have an id field`);
}
let result = `allFieldsEqual(${this.transform(expr.left, false)},
${this.transform(expr.right, false)}, [${idFields.map((f) => "'" + f.name + "'").join(', ')}])`;
if (expr.operator === '!=') {
result = `!${result}`;
}
return result;
} else {
return `(${this.transform(expr.left, normalizeUndefined)} ${expr.operator} ${this.transform(
expr.right,
Expand Down
17 changes: 2 additions & 15 deletions packages/schema/tests/generator/expression-writer.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,26 +113,13 @@ describe('Expression Writer Tests', () => {
it('this reference', async () => {
await check(
`
model User { id String @id }
model Test {
id String @id
@@allow('all', auth() == this)
@@allow('all', this == this)
}
`,
(model) => model.attributes[0].args[1].value,
`(user == null) ? { OR: [] } : { id: user.id }`
);

await check(
`
model User { id String @id }
model Test {
id String @id
@@deny('all', this != auth())
}
`,
(model) => model.attributes[0].args[1].value,
`(user == null) ? { AND: [] } : { NOT: { id: user.id } }`
`{OR:[]}`
);

await check(
Expand Down
2 changes: 1 addition & 1 deletion packages/schema/tests/generator/prisma-generator.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ describe('Prisma generator test', () => {
const post = dmmf.datamodel.models.find((m) => m.name === 'Post');

expect(post?.documentation?.replace(/\s/g, '')).toBe(
`@@allow('read', owner == auth()) @@allow('delete', ownerId == auth())`.replace(/\s/g, '')
`@@allow('read', owner == auth()) @@allow('delete', owner == auth())`.replace(/\s/g, '')
);

const todo = dmmf.datamodel.models.find((m) => m.name === 'Todo');
Expand Down
2 changes: 1 addition & 1 deletion packages/schema/tests/generator/zmodel/schema.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ model Post extends Basic {
title String
content String?

@@allow('delete', ownerId == auth())
@@allow('delete', owner == auth())
}

model Todo extends Basic {
Expand Down
Loading