Skip to content

feat: more flexible "in" operator and filter expressions #367

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { BinaryExpr, Expression, isArrayExpr, isBinaryExpr, isEnum, isLiteralExpr } from '@zenstackhq/language/ast';
import { BinaryExpr, Expression, isBinaryExpr, isEnum } from '@zenstackhq/language/ast';
import { ValidationAcceptor } from 'langium';
import { getDataModelFieldReference, isAuthInvocation, isEnumFieldReference } from '../../utils/ast-utils';
import { isAuthInvocation } from '../../utils/ast-utils';
import { AstValidator } from '../types';

/**
Expand Down Expand Up @@ -37,21 +37,12 @@ export default class ExpressionValidator implements AstValidator<Expression> {
private validateBinaryExpr(expr: BinaryExpr, accept: ValidationAcceptor) {
switch (expr.operator) {
case 'in': {
if (!getDataModelFieldReference(expr.left)) {
accept('error', 'left operand of "in" must be a field reference', { node: expr.left });
}

if (typeof expr.left.$resolvedType?.decl !== 'string' && !isEnum(expr.left.$resolvedType?.decl)) {
accept('error', 'left operand of "in" must be of scalar type', { node: expr.left });
}

if (
!(
isArrayExpr(expr.right) &&
expr.right.items.every((item) => isLiteralExpr(item) || isEnumFieldReference(item))
)
) {
accept('error', 'right operand of "in" must be an array of literals or enum values', {
if (!expr.right.$resolvedType?.array) {
accept('error', 'right operand of "in" must be an array', {
node: expr.right,
});
}
Expand Down
41 changes: 34 additions & 7 deletions packages/schema/src/plugins/access-policy/expression-writer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,35 @@ export class ExpressionWriter {
}

private writeIn(expr: BinaryExpr) {
const leftIsFieldAccess = this.isFieldAccess(expr.left);
const rightIsFieldAccess = this.isFieldAccess(expr.right);

this.block(() => {
this.writeFieldCondition(
expr.left,
() => {
this.plain(expr.right);
},
'in'
);
if (!leftIsFieldAccess && !rightIsFieldAccess) {
// 'in' without referencing fields
this.guard(() => this.plain(expr));
} else if (leftIsFieldAccess && !rightIsFieldAccess) {
// 'in' with left referencing a field, right is an array literal
this.writeFieldCondition(
expr.left,
() => {
this.plain(expr.right);
},
'in'
);
} else if (!leftIsFieldAccess && rightIsFieldAccess) {
// 'in' with right referencing an array field, left is a literal
// transform it into a 'has' filter
this.writeFieldCondition(
expr.right,
() => {
this.plain(expr.left);
},
'has'
);
} else {
throw new PluginError('"in" operator cannot be used with field references on both sides');
}
});
}

Expand Down Expand Up @@ -520,6 +541,12 @@ export class ExpressionWriter {
}

if (FILTER_OPERATOR_FUNCTIONS.includes(funcDecl.name)) {
if (!expr.args.some((arg) => this.isFieldAccess(arg.value))) {
// filter functions without referencing fields
this.block(() => this.guard(() => this.plain(expr)));
return;
}

let valueArg = expr.args[1]?.value;

// isEmpty function is zero arity, it's mapped to a boolean literal
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ export default class PolicyGenerator {

const project = createProject();
const sf = project.createSourceFile(path.join(output, 'policy.ts'), undefined, { overwrite: true });
sf.addStatements('/* eslint-disable */');

sf.addImportDeclaration({
namedImports: [{ name: 'type QueryContext' }, { name: 'hasAllFields' }],
Expand Down Expand Up @@ -361,7 +362,7 @@ export default class PolicyGenerator {
func.addStatements(
`const user = hasAllFields(context.user, [${userIdFields
.map((f) => "'" + f.name + "'")
.join(', ')}]) ? context.user : null;`
.join(', ')}]) ? context.user as any : null;`
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ import {
ThisExpr,
UnaryExpr,
} from '@zenstackhq/language/ast';
import { PluginError } from '@zenstackhq/sdk';
import { getLiteral, PluginError } from '@zenstackhq/sdk';
import { FILTER_OPERATOR_FUNCTIONS } from '../../language-server/constants';
import { isAuthInvocation } from '../../utils/ast-utils';
import { isFutureExpr } from './utils';

Expand All @@ -28,17 +29,17 @@ export default class TypeScriptExpressionTransformer {
constructor(private readonly isPostGuard = false) {}

/**
*
* @param expr
* Transforms the given expression to a TypeScript expression.
* @param normalizeUndefined if undefined values should be normalized to null
* @returns
*/
transform(expr: Expression): string {
transform(expr: Expression, normalizeUndefined = true): string {
switch (expr.$type) {
case LiteralExpr:
return this.literal(expr as LiteralExpr);

case ArrayExpr:
return this.array(expr as ArrayExpr);
return this.array(expr as ArrayExpr, normalizeUndefined);

case NullExpr:
return this.null();
Expand All @@ -50,16 +51,16 @@ export default class TypeScriptExpressionTransformer {
return this.reference(expr as ReferenceExpr);

case InvocationExpr:
return this.invocation(expr as InvocationExpr);
return this.invocation(expr as InvocationExpr, normalizeUndefined);

case MemberAccessExpr:
return this.memberAccess(expr as MemberAccessExpr);
return this.memberAccess(expr as MemberAccessExpr, normalizeUndefined);

case UnaryExpr:
return this.unary(expr as UnaryExpr);
return this.unary(expr as UnaryExpr, normalizeUndefined);

case BinaryExpr:
return this.binary(expr as BinaryExpr);
return this.binary(expr as BinaryExpr, normalizeUndefined);

default:
throw new PluginError(`Unsupported expression type: ${expr.$type}`);
Expand All @@ -72,7 +73,7 @@ export default class TypeScriptExpressionTransformer {
return 'id';
}

private memberAccess(expr: MemberAccessExpr) {
private memberAccess(expr: MemberAccessExpr, normalizeUndefined: boolean) {
if (!expr.member.ref) {
throw new PluginError(`Unresolved MemberAccessExpr`);
}
Expand All @@ -85,14 +86,71 @@ export default class TypeScriptExpressionTransformer {
}
return expr.member.ref.name;
} else {
// normalize field access to null instead of undefined to avoid accidentally use undefined in filter
return `(${this.transform(expr.operand)}?.${expr.member.ref.name} ?? null)`;
if (normalizeUndefined) {
// normalize field access to null instead of undefined to avoid accidentally use undefined in filter
return `(${this.transform(expr.operand, normalizeUndefined)}?.${expr.member.ref.name} ?? null)`;
} else {
return `${this.transform(expr.operand, normalizeUndefined)}?.${expr.member.ref.name}`;
}
}
}

private invocation(expr: InvocationExpr) {
private invocation(expr: InvocationExpr, normalizeUndefined: boolean) {
if (!expr.function.ref) {
throw new PluginError(`Unresolved InvocationExpr`);
}

if (isAuthInvocation(expr)) {
return 'user';
} else if (FILTER_OPERATOR_FUNCTIONS.includes(expr.function.ref.name)) {
// arguments are already type-checked

const arg0 = this.transform(expr.args[0].value, false);
let result: string;
switch (expr.function.ref.name) {
case 'contains': {
const caseInsensitive = getLiteral<boolean>(expr.args[2]?.value) === true;
if (caseInsensitive) {
result = `${arg0}?.toLowerCase().includes(${this.transform(
expr.args[1].value,
normalizeUndefined
)}?.toLowerCase())`;
} else {
result = `${arg0}?.includes(${this.transform(expr.args[1].value, normalizeUndefined)})`;
}
break;
}
case 'search':
throw new PluginError('"search" function must be used against a field');
case 'startsWith':
result = `${arg0}?.startsWith(${this.transform(expr.args[1].value, normalizeUndefined)})`;
break;
case 'endsWith':
result = `${arg0}?.endsWith(${this.transform(expr.args[1].value, normalizeUndefined)})`;
break;
case 'has':
result = `${arg0}?.includes(${this.transform(expr.args[1].value, normalizeUndefined)})`;
break;
case 'hasEvery':
result = `${this.transform(
expr.args[1].value,
normalizeUndefined
)}?.every((item) => ${arg0}?.includes(item))`;
break;
case 'hasSome':
result = `${this.transform(
expr.args[1].value,
normalizeUndefined
)}?.some((item) => ${arg0}?.includes(item))`;
break;
case 'isEmpty':
result = `${arg0}?.length === 0`;
break;
default:
throw new PluginError(`Function invocation is not supported: ${expr.function.ref?.name}`);
}

return `(${result} ?? false)`;
} else {
throw new PluginError(`Function invocation is not supported: ${expr.function.ref?.name}`);
}
Expand Down Expand Up @@ -121,8 +179,8 @@ export default class TypeScriptExpressionTransformer {
return 'null';
}

private array(expr: ArrayExpr) {
return `[${expr.items.map((item) => this.transform(item)).join(', ')}]`;
private array(expr: ArrayExpr, normalizeUndefined: boolean) {
return `[${expr.items.map((item) => this.transform(item, normalizeUndefined)).join(', ')}]`;
}

private literal(expr: LiteralExpr) {
Expand All @@ -133,11 +191,18 @@ export default class TypeScriptExpressionTransformer {
}
}

private unary(expr: UnaryExpr): string {
return `(${expr.operator} ${this.transform(expr.operand)})`;
private unary(expr: UnaryExpr, normalizeUndefined: boolean): string {
return `(${expr.operator} ${this.transform(expr.operand, normalizeUndefined)})`;
}

private binary(expr: BinaryExpr): string {
return `(${this.transform(expr.left)} ${expr.operator} ${this.transform(expr.right)})`;
private binary(expr: BinaryExpr, normalizeUndefined: boolean): string {
if (expr.operator === 'in') {
return `(${this.transform(expr.right, false)}?.includes(${this.transform(
expr.left,
normalizeUndefined
)}) ?? false)`;
} else {
return `(${this.transform(expr.left)} ${expr.operator} ${this.transform(expr.right, normalizeUndefined)})`;
}
}
}
1 change: 1 addition & 0 deletions packages/schema/src/plugins/model-meta/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ export default async function run(model: Model, options: PluginOptions) {
}

const sf = project.createSourceFile(path.join(output, 'model-meta.ts'), undefined, { overwrite: true });
sf.addStatements('/* eslint-disable */');
sf.addVariableStatement({
declarationKind: VariableDeclarationKind.Const,
declarations: [{ name: 'metadata', initializer: (writer) => generateModelMetadata(dataModels, writer) }],
Expand Down
5 changes: 3 additions & 2 deletions packages/schema/src/res/stdlib.zmodel
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,10 @@ function future(): Any {
}

/*
* If the field value contains the search string
* If the field value contains the search string. By default, the search is case-sensitive,
* but you can override the behavior with the "caseInSensitive" argument.
*/
function contains(field: String, search: String, caseSensitive: Boolean?): Boolean {
function contains(field: String, search: String, caseInSensitive: Boolean?): Boolean {
}

/*
Expand Down
Loading