Skip to content

fix: short-circuit post-read check when policy rules don't depend on model fields #376

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "zenstack-monorepo",
"version": "1.0.0-alpha.112",
"version": "1.0.0-alpha.113",
"description": "",
"scripts": {
"build": "pnpm -r build",
Expand Down
2 changes: 1 addition & 1 deletion packages/language/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@zenstackhq/language",
"version": "1.0.0-alpha.112",
"version": "1.0.0-alpha.113",
"displayName": "ZenStack modeling language compiler",
"description": "ZenStack modeling language compiler",
"homepage": "https://zenstack.dev",
Expand Down
2 changes: 1 addition & 1 deletion packages/next/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@zenstackhq/next",
"version": "1.0.0-alpha.112",
"version": "1.0.0-alpha.113",
"displayName": "ZenStack Next.js integration",
"description": "ZenStack Next.js integration",
"homepage": "https://zenstack.dev",
Expand Down
2 changes: 1 addition & 1 deletion packages/plugins/openapi/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/openapi",
"displayName": "ZenStack Plugin and Runtime for OpenAPI",
"version": "1.0.0-alpha.112",
"version": "1.0.0-alpha.113",
"description": "ZenStack plugin and runtime supporting OpenAPI",
"main": "index.js",
"repository": {
Expand Down
2 changes: 1 addition & 1 deletion packages/plugins/react/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/react",
"displayName": "ZenStack plugin and runtime for ReactJS",
"version": "1.0.0-alpha.112",
"version": "1.0.0-alpha.113",
"description": "ZenStack plugin and runtime for ReactJS",
"main": "index.js",
"repository": {
Expand Down
2 changes: 1 addition & 1 deletion packages/plugins/trpc/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/trpc",
"displayName": "ZenStack plugin for tRPC",
"version": "1.0.0-alpha.112",
"version": "1.0.0-alpha.113",
"description": "ZenStack plugin for tRPC",
"main": "index.js",
"repository": {
Expand Down
2 changes: 1 addition & 1 deletion packages/runtime/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/runtime",
"displayName": "ZenStack Runtime Library",
"version": "1.0.0-alpha.112",
"version": "1.0.0-alpha.113",
"description": "Runtime of ZenStack for both client-side and server-side environments.",
"repository": {
"type": "git",
Expand Down
24 changes: 21 additions & 3 deletions packages/runtime/src/enhancements/policy/policy-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,10 @@ export class PolicyUtil {
* omitted.
*/
async postProcessForRead(entityData: any, model: string, args: any, operation: PolicyOperationKind) {
if (typeof entityData !== 'object' || !entityData) {
return;
}

const ids = this.getEntityIds(model, entityData);
if (Object.keys(ids).length === 0) {
return;
Expand Down Expand Up @@ -739,6 +743,14 @@ export class PolicyUtil {
operation: PolicyOperationKind,
db: Record<string, DbOperations>
) {
const guard = await this.getAuthGuard(model, operation);
const schema = (operation === 'create' || operation === 'update') && (await this.getModelSchema(model));

if (guard === true && !schema) {
// unconditionally allowed
return;
}

// DEBUG
// this.logger.info(`Checking policy for ${model}#${JSON.stringify(filter)} for ${operation}`);

Expand All @@ -750,13 +762,19 @@ export class PolicyUtil {
await this.flattenGeneratedUniqueField(model, queryFilter);

const count = (await db[model].count({ where: queryFilter })) as number;
const guard = await this.getAuthGuard(model, operation);
if (count === 0) {
// there's nothing to filter out
return;
}

if (guard === false) {
// unconditionally denied
throw this.deniedByPolicy(model, operation, `${count} ${pluralize('entity', count)} failed policy check`);
}

// build a query condition with policy injected
const guardedQuery = { where: this.and(queryFilter, guard) };

const schema = (operation === 'create' || operation === 'update') && (await this.getModelSchema(model));

if (schema) {
// we've got schemas, so have to fetch entities and validate them
const entities = await db[model].findMany(guardedQuery);
Expand Down
2 changes: 1 addition & 1 deletion packages/schema/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"publisher": "zenstack",
"displayName": "ZenStack Language Tools",
"description": "A toolkit for building secure CRUD apps with Next.js + Typescript",
"version": "1.0.0-alpha.112",
"version": "1.0.0-alpha.113",
"author": {
"name": "ZenStack Team"
},
Expand Down
123 changes: 71 additions & 52 deletions packages/schema/src/plugins/access-policy/policy-guard-generator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {
isInvocationExpr,
isMemberAccessExpr,
isReferenceExpr,
isThisExpr,
isUnaryExpr,
MemberAccessExpr,
Model,
Expand All @@ -33,9 +34,10 @@ import path from 'path';
import { FunctionDeclaration, SourceFile, VariableDeclarationKind } from 'ts-morph';
import { name } from '.';
import { isFromStdlib } from '../../language-server/utils';
import { getIdFields } from '../../utils/ast-utils';
import { getIdFields, isAuthInvocation } from '../../utils/ast-utils';
import { ALL_OPERATION_KINDS, getDefaultOutputFolder } from '../plugin-utils';
import { ExpressionWriter } from './expression-writer';
import TypeScriptExpressionTransformer from './typescript-expression-transformer';
import { isFutureExpr } from './utils';
import { ZodSchemaGenerator } from './zod-schema-generator';

Expand Down Expand Up @@ -332,18 +334,9 @@ export default class PolicyGenerator {
.addBody();

// check if any allow or deny rule contains 'auth()' invocation
let hasAuthRef = false;
for (const node of [...denies, ...allows]) {
for (const child of streamAllContents(node)) {
if (isInvocationExpr(child) && resolved(child.function).name === 'auth') {
hasAuthRef = true;
break;
}
}
if (hasAuthRef) {
break;
}
}
const hasAuthRef = [...denies, ...allows].some((rule) =>
streamAllContents(rule).some((child) => isAuthInvocation(child))
);

if (hasAuthRef) {
const userModel = model.$container.declarations.find(
Expand All @@ -365,47 +358,73 @@ export default class PolicyGenerator {
);
}

// r = <guard object>;
func.addStatements((writer) => {
writer.write('return ');
const exprWriter = new ExpressionWriter(writer, kind === 'postUpdate');
const writeDenies = () => {
writer.conditionalWrite(denies.length > 1, '{ AND: [');
denies.forEach((expr, i) => {
writer.inlineBlock(() => {
writer.write('NOT: ');
exprWriter.write(expr);
});
writer.conditionalWrite(i !== denies.length - 1, ',');
const hasFieldAccess = [...denies, ...allows].some((rule) =>
streamAllContents(rule).some(
(child) =>
// this.???
isThisExpr(child) ||
// future().???
isFutureExpr(child) ||
// field reference
(isReferenceExpr(child) && isDataModelField(child.target.ref))
)
);

if (!hasFieldAccess) {
// none of the rules reference model fields, we can compile down to a plain boolean
// function in this case (so we can skip doing SQL queries when validating)
func.addStatements((writer) => {
const transformer = new TypeScriptExpressionTransformer(kind === 'postUpdate');
denies.forEach((rule) => {
writer.write(`if (${transformer.transform(rule, false)}) { return false; }`);
});
writer.conditionalWrite(denies.length > 1, ']}');
};

const writeAllows = () => {
writer.conditionalWrite(allows.length > 1, '{ OR: [');
allows.forEach((expr, i) => {
exprWriter.write(expr);
writer.conditionalWrite(i !== allows.length - 1, ',');
allows.forEach((rule) => {
writer.write(`if (${transformer.transform(rule, false)}) { return true; }`);
});
writer.conditionalWrite(allows.length > 1, ']}');
};

if (allows.length > 0 && denies.length > 0) {
writer.write('{ AND: [');
writeDenies();
writer.write(',');
writeAllows();
writer.write(']}');
} else if (denies.length > 0) {
writeDenies();
} else if (allows.length > 0) {
writeAllows();
} else {
// disallow any operation
writer.write(`{ ${GUARD_FIELD_NAME}: false }`);
}
writer.write(';');
});
writer.write('return false;');
});
} else {
func.addStatements((writer) => {
writer.write('return ');
const exprWriter = new ExpressionWriter(writer, kind === 'postUpdate');
const writeDenies = () => {
writer.conditionalWrite(denies.length > 1, '{ AND: [');
denies.forEach((expr, i) => {
writer.inlineBlock(() => {
writer.write('NOT: ');
exprWriter.write(expr);
});
writer.conditionalWrite(i !== denies.length - 1, ',');
});
writer.conditionalWrite(denies.length > 1, ']}');
};

const writeAllows = () => {
writer.conditionalWrite(allows.length > 1, '{ OR: [');
allows.forEach((expr, i) => {
exprWriter.write(expr);
writer.conditionalWrite(i !== allows.length - 1, ',');
});
writer.conditionalWrite(allows.length > 1, ']}');
};

if (allows.length > 0 && denies.length > 0) {
writer.write('{ AND: [');
writeDenies();
writer.write(',');
writeAllows();
writer.write(']}');
} else if (denies.length > 0) {
writeDenies();
} else if (allows.length > 0) {
writeAllows();
} else {
// disallow any operation
writer.write(`{ ${GUARD_FIELD_NAME}: false }`);
}
writer.write(';');
});
}
return func;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,10 @@ export default class TypeScriptExpressionTransformer {
normalizeUndefined
)}) ?? false)`;
} else {
return `(${this.transform(expr.left)} ${expr.operator} ${this.transform(expr.right, normalizeUndefined)})`;
return `(${this.transform(expr.left, normalizeUndefined)} ${expr.operator} ${this.transform(
expr.right,
normalizeUndefined
)})`;
}
}
}
7 changes: 4 additions & 3 deletions packages/schema/src/plugins/access-policy/utils.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { Expression, isInvocationExpr } from '@zenstackhq/language/ast';
import { isInvocationExpr } from '@zenstackhq/language/ast';
import { AstNode } from 'langium/lib/syntax-tree';
import { isFromStdlib } from '../../language-server/utils';

/**
* Returns if the given expression is a "future()" method call.
*/
export function isFutureExpr(expr: Expression) {
return !!(isInvocationExpr(expr) && expr.function.ref?.name === 'future' && isFromStdlib(expr.function.ref));
export function isFutureExpr(node: AstNode) {
return !!(isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref));
}
9 changes: 4 additions & 5 deletions packages/schema/src/utils/ast-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@ import {
} from '@zenstackhq/language/ast';
import { PolicyOperationKind } from '@zenstackhq/runtime';
import { getLiteral } from '@zenstackhq/sdk';
import { AstNode, Mutable } from 'langium';
import { isFromStdlib } from '../language-server/utils';
import { getDocument, LangiumDocuments } from 'langium';
import { AstNode, getDocument, LangiumDocuments, Mutable } from 'langium';
import { URI, Utils } from 'vscode-uri';
import { isFromStdlib } from '../language-server/utils';

export function extractDataModelsWithAllowRules(model: Model): DataModel[] {
return model.declarations.filter(
Expand Down Expand Up @@ -163,8 +162,8 @@ export function getIdFields(dataModel: DataModel) {
return [];
}

export function isAuthInvocation(expr: Expression) {
return isInvocationExpr(expr) && expr.function.ref?.name === 'auth' && isFromStdlib(expr.function.ref);
export function isAuthInvocation(node: AstNode) {
return isInvocationExpr(node) && node.function.ref?.name === 'auth' && isFromStdlib(node.function.ref);
}

export function isEnumFieldReference(expr: Expression) {
Expand Down
71 changes: 71 additions & 0 deletions packages/schema/tests/plugins/policy.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import { loadSchema } from '@zenstackhq/testtools';

describe('Policy plugin tests', () => {
let origDir: string;

beforeEach(() => {
origDir = process.cwd();
});

afterEach(() => {
process.chdir(origDir);
});

it('short-circuit', async () => {
const model = `
model User {
id String @id @default(cuid())
value Int
}

model M {
id String @id @default(cuid())
value Int
@@allow('read', auth() != null)
@@allow('create', auth().value > 0)

@@allow('update', auth() != null)
@@deny('update', auth().value == null || auth().value <= 0)
}
`;

const { policy } = await loadSchema(model);

expect(policy.guard.m.read({ user: undefined })).toEqual(false);
expect(policy.guard.m.read({ user: { id: '1' } })).toEqual(true);

expect(policy.guard.m.create({ user: undefined })).toEqual(false);
expect(policy.guard.m.create({ user: { id: '1' } })).toEqual(false);
expect(policy.guard.m.create({ user: { id: '1', value: 0 } })).toEqual(false);
expect(policy.guard.m.create({ user: { id: '1', value: 1 } })).toEqual(true);

expect(policy.guard.m.update({ user: undefined })).toEqual(false);
expect(policy.guard.m.update({ user: { id: '1' } })).toEqual(false);
expect(policy.guard.m.update({ user: { id: '1', value: 0 } })).toEqual(false);
expect(policy.guard.m.update({ user: { id: '1', value: 1 } })).toEqual(true);
});

it('no short-circuit', async () => {
const model = `
model User {
id String @id @default(cuid())
value Int
}

model M {
id String @id @default(cuid())
value Int
@@allow('read', auth() != null && value > 0)
}
`;

const { policy } = await loadSchema(model);

expect(policy.guard.m.read({ user: undefined })).toEqual(
expect.objectContaining({ AND: [{ zenstack_guard: false }, { value: { gt: 0 } }] })
);
expect(policy.guard.m.read({ user: { id: '1' } })).toEqual(
expect.objectContaining({ AND: [{ zenstack_guard: true }, { value: { gt: 0 } }] })
);
});
});
Loading