Skip to content

Commit a54eba4

Browse files
authored
fix: short-circuit post-read check when policy rules don't depend on model fields (#376)
1 parent 4bf1304 commit a54eba4

File tree

19 files changed

+190
-77
lines changed

19 files changed

+190
-77
lines changed

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "zenstack-monorepo",
3-
"version": "1.0.0-alpha.112",
3+
"version": "1.0.0-alpha.113",
44
"description": "",
55
"scripts": {
66
"build": "pnpm -r build",

packages/language/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@zenstackhq/language",
3-
"version": "1.0.0-alpha.112",
3+
"version": "1.0.0-alpha.113",
44
"displayName": "ZenStack modeling language compiler",
55
"description": "ZenStack modeling language compiler",
66
"homepage": "https://zenstack.dev",

packages/next/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "@zenstackhq/next",
3-
"version": "1.0.0-alpha.112",
3+
"version": "1.0.0-alpha.113",
44
"displayName": "ZenStack Next.js integration",
55
"description": "ZenStack Next.js integration",
66
"homepage": "https://zenstack.dev",

packages/plugins/openapi/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "@zenstackhq/openapi",
33
"displayName": "ZenStack Plugin and Runtime for OpenAPI",
4-
"version": "1.0.0-alpha.112",
4+
"version": "1.0.0-alpha.113",
55
"description": "ZenStack plugin and runtime supporting OpenAPI",
66
"main": "index.js",
77
"repository": {

packages/plugins/react/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "@zenstackhq/react",
33
"displayName": "ZenStack plugin and runtime for ReactJS",
4-
"version": "1.0.0-alpha.112",
4+
"version": "1.0.0-alpha.113",
55
"description": "ZenStack plugin and runtime for ReactJS",
66
"main": "index.js",
77
"repository": {

packages/plugins/trpc/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "@zenstackhq/trpc",
33
"displayName": "ZenStack plugin for tRPC",
4-
"version": "1.0.0-alpha.112",
4+
"version": "1.0.0-alpha.113",
55
"description": "ZenStack plugin for tRPC",
66
"main": "index.js",
77
"repository": {

packages/runtime/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "@zenstackhq/runtime",
33
"displayName": "ZenStack Runtime Library",
4-
"version": "1.0.0-alpha.112",
4+
"version": "1.0.0-alpha.113",
55
"description": "Runtime of ZenStack for both client-side and server-side environments.",
66
"repository": {
77
"type": "git",

packages/runtime/src/enhancements/policy/policy-utils.ts

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,10 @@ export class PolicyUtil {
321321
* omitted.
322322
*/
323323
async postProcessForRead(entityData: any, model: string, args: any, operation: PolicyOperationKind) {
324+
if (typeof entityData !== 'object' || !entityData) {
325+
return;
326+
}
327+
324328
const ids = this.getEntityIds(model, entityData);
325329
if (Object.keys(ids).length === 0) {
326330
return;
@@ -739,6 +743,14 @@ export class PolicyUtil {
739743
operation: PolicyOperationKind,
740744
db: Record<string, DbOperations>
741745
) {
746+
const guard = await this.getAuthGuard(model, operation);
747+
const schema = (operation === 'create' || operation === 'update') && (await this.getModelSchema(model));
748+
749+
if (guard === true && !schema) {
750+
// unconditionally allowed
751+
return;
752+
}
753+
742754
// DEBUG
743755
// this.logger.info(`Checking policy for ${model}#${JSON.stringify(filter)} for ${operation}`);
744756

@@ -750,13 +762,19 @@ export class PolicyUtil {
750762
await this.flattenGeneratedUniqueField(model, queryFilter);
751763

752764
const count = (await db[model].count({ where: queryFilter })) as number;
753-
const guard = await this.getAuthGuard(model, operation);
765+
if (count === 0) {
766+
// there's nothing to filter out
767+
return;
768+
}
769+
770+
if (guard === false) {
771+
// unconditionally denied
772+
throw this.deniedByPolicy(model, operation, `${count} ${pluralize('entity', count)} failed policy check`);
773+
}
754774

755775
// build a query condition with policy injected
756776
const guardedQuery = { where: this.and(queryFilter, guard) };
757777

758-
const schema = (operation === 'create' || operation === 'update') && (await this.getModelSchema(model));
759-
760778
if (schema) {
761779
// we've got schemas, so have to fetch entities and validate them
762780
const entities = await db[model].findMany(guardedQuery);

packages/schema/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"publisher": "zenstack",
44
"displayName": "ZenStack Language Tools",
55
"description": "A toolkit for building secure CRUD apps with Next.js + Typescript",
6-
"version": "1.0.0-alpha.112",
6+
"version": "1.0.0-alpha.113",
77
"author": {
88
"name": "ZenStack Team"
99
},

packages/schema/src/plugins/access-policy/policy-guard-generator.ts

Lines changed: 71 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import {
99
isInvocationExpr,
1010
isMemberAccessExpr,
1111
isReferenceExpr,
12+
isThisExpr,
1213
isUnaryExpr,
1314
MemberAccessExpr,
1415
Model,
@@ -33,9 +34,10 @@ import path from 'path';
3334
import { FunctionDeclaration, SourceFile, VariableDeclarationKind } from 'ts-morph';
3435
import { name } from '.';
3536
import { isFromStdlib } from '../../language-server/utils';
36-
import { getIdFields } from '../../utils/ast-utils';
37+
import { getIdFields, isAuthInvocation } from '../../utils/ast-utils';
3738
import { ALL_OPERATION_KINDS, getDefaultOutputFolder } from '../plugin-utils';
3839
import { ExpressionWriter } from './expression-writer';
40+
import TypeScriptExpressionTransformer from './typescript-expression-transformer';
3941
import { isFutureExpr } from './utils';
4042
import { ZodSchemaGenerator } from './zod-schema-generator';
4143

@@ -332,18 +334,9 @@ export default class PolicyGenerator {
332334
.addBody();
333335

334336
// check if any allow or deny rule contains 'auth()' invocation
335-
let hasAuthRef = false;
336-
for (const node of [...denies, ...allows]) {
337-
for (const child of streamAllContents(node)) {
338-
if (isInvocationExpr(child) && resolved(child.function).name === 'auth') {
339-
hasAuthRef = true;
340-
break;
341-
}
342-
}
343-
if (hasAuthRef) {
344-
break;
345-
}
346-
}
337+
const hasAuthRef = [...denies, ...allows].some((rule) =>
338+
streamAllContents(rule).some((child) => isAuthInvocation(child))
339+
);
347340

348341
if (hasAuthRef) {
349342
const userModel = model.$container.declarations.find(
@@ -365,47 +358,73 @@ export default class PolicyGenerator {
365358
);
366359
}
367360

368-
// r = <guard object>;
369-
func.addStatements((writer) => {
370-
writer.write('return ');
371-
const exprWriter = new ExpressionWriter(writer, kind === 'postUpdate');
372-
const writeDenies = () => {
373-
writer.conditionalWrite(denies.length > 1, '{ AND: [');
374-
denies.forEach((expr, i) => {
375-
writer.inlineBlock(() => {
376-
writer.write('NOT: ');
377-
exprWriter.write(expr);
378-
});
379-
writer.conditionalWrite(i !== denies.length - 1, ',');
361+
const hasFieldAccess = [...denies, ...allows].some((rule) =>
362+
streamAllContents(rule).some(
363+
(child) =>
364+
// this.???
365+
isThisExpr(child) ||
366+
// future().???
367+
isFutureExpr(child) ||
368+
// field reference
369+
(isReferenceExpr(child) && isDataModelField(child.target.ref))
370+
)
371+
);
372+
373+
if (!hasFieldAccess) {
374+
// none of the rules reference model fields, we can compile down to a plain boolean
375+
// function in this case (so we can skip doing SQL queries when validating)
376+
func.addStatements((writer) => {
377+
const transformer = new TypeScriptExpressionTransformer(kind === 'postUpdate');
378+
denies.forEach((rule) => {
379+
writer.write(`if (${transformer.transform(rule, false)}) { return false; }`);
380380
});
381-
writer.conditionalWrite(denies.length > 1, ']}');
382-
};
383-
384-
const writeAllows = () => {
385-
writer.conditionalWrite(allows.length > 1, '{ OR: [');
386-
allows.forEach((expr, i) => {
387-
exprWriter.write(expr);
388-
writer.conditionalWrite(i !== allows.length - 1, ',');
381+
allows.forEach((rule) => {
382+
writer.write(`if (${transformer.transform(rule, false)}) { return true; }`);
389383
});
390-
writer.conditionalWrite(allows.length > 1, ']}');
391-
};
392-
393-
if (allows.length > 0 && denies.length > 0) {
394-
writer.write('{ AND: [');
395-
writeDenies();
396-
writer.write(',');
397-
writeAllows();
398-
writer.write(']}');
399-
} else if (denies.length > 0) {
400-
writeDenies();
401-
} else if (allows.length > 0) {
402-
writeAllows();
403-
} else {
404-
// disallow any operation
405-
writer.write(`{ ${GUARD_FIELD_NAME}: false }`);
406-
}
407-
writer.write(';');
408-
});
384+
writer.write('return false;');
385+
});
386+
} else {
387+
func.addStatements((writer) => {
388+
writer.write('return ');
389+
const exprWriter = new ExpressionWriter(writer, kind === 'postUpdate');
390+
const writeDenies = () => {
391+
writer.conditionalWrite(denies.length > 1, '{ AND: [');
392+
denies.forEach((expr, i) => {
393+
writer.inlineBlock(() => {
394+
writer.write('NOT: ');
395+
exprWriter.write(expr);
396+
});
397+
writer.conditionalWrite(i !== denies.length - 1, ',');
398+
});
399+
writer.conditionalWrite(denies.length > 1, ']}');
400+
};
401+
402+
const writeAllows = () => {
403+
writer.conditionalWrite(allows.length > 1, '{ OR: [');
404+
allows.forEach((expr, i) => {
405+
exprWriter.write(expr);
406+
writer.conditionalWrite(i !== allows.length - 1, ',');
407+
});
408+
writer.conditionalWrite(allows.length > 1, ']}');
409+
};
410+
411+
if (allows.length > 0 && denies.length > 0) {
412+
writer.write('{ AND: [');
413+
writeDenies();
414+
writer.write(',');
415+
writeAllows();
416+
writer.write(']}');
417+
} else if (denies.length > 0) {
418+
writeDenies();
419+
} else if (allows.length > 0) {
420+
writeAllows();
421+
} else {
422+
// disallow any operation
423+
writer.write(`{ ${GUARD_FIELD_NAME}: false }`);
424+
}
425+
writer.write(';');
426+
});
427+
}
409428
return func;
410429
}
411430
}

packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,10 @@ export default class TypeScriptExpressionTransformer {
202202
normalizeUndefined
203203
)}) ?? false)`;
204204
} else {
205-
return `(${this.transform(expr.left)} ${expr.operator} ${this.transform(expr.right, normalizeUndefined)})`;
205+
return `(${this.transform(expr.left, normalizeUndefined)} ${expr.operator} ${this.transform(
206+
expr.right,
207+
normalizeUndefined
208+
)})`;
206209
}
207210
}
208211
}
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
import { Expression, isInvocationExpr } from '@zenstackhq/language/ast';
1+
import { isInvocationExpr } from '@zenstackhq/language/ast';
2+
import { AstNode } from 'langium/lib/syntax-tree';
23
import { isFromStdlib } from '../../language-server/utils';
34

45
/**
56
* Returns if the given expression is a "future()" method call.
67
*/
7-
export function isFutureExpr(expr: Expression) {
8-
return !!(isInvocationExpr(expr) && expr.function.ref?.name === 'future' && isFromStdlib(expr.function.ref));
8+
export function isFutureExpr(node: AstNode) {
9+
return !!(isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref));
910
}

packages/schema/src/utils/ast-utils.ts

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,9 @@ import {
1717
} from '@zenstackhq/language/ast';
1818
import { PolicyOperationKind } from '@zenstackhq/runtime';
1919
import { getLiteral } from '@zenstackhq/sdk';
20-
import { AstNode, Mutable } from 'langium';
21-
import { isFromStdlib } from '../language-server/utils';
22-
import { getDocument, LangiumDocuments } from 'langium';
20+
import { AstNode, getDocument, LangiumDocuments, Mutable } from 'langium';
2321
import { URI, Utils } from 'vscode-uri';
22+
import { isFromStdlib } from '../language-server/utils';
2423

2524
export function extractDataModelsWithAllowRules(model: Model): DataModel[] {
2625
return model.declarations.filter(
@@ -163,8 +162,8 @@ export function getIdFields(dataModel: DataModel) {
163162
return [];
164163
}
165164

166-
export function isAuthInvocation(expr: Expression) {
167-
return isInvocationExpr(expr) && expr.function.ref?.name === 'auth' && isFromStdlib(expr.function.ref);
165+
export function isAuthInvocation(node: AstNode) {
166+
return isInvocationExpr(node) && node.function.ref?.name === 'auth' && isFromStdlib(node.function.ref);
168167
}
169168

170169
export function isEnumFieldReference(expr: Expression) {
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import { loadSchema } from '@zenstackhq/testtools';
2+
3+
describe('Policy plugin tests', () => {
4+
let origDir: string;
5+
6+
beforeEach(() => {
7+
origDir = process.cwd();
8+
});
9+
10+
afterEach(() => {
11+
process.chdir(origDir);
12+
});
13+
14+
it('short-circuit', async () => {
15+
const model = `
16+
model User {
17+
id String @id @default(cuid())
18+
value Int
19+
}
20+
21+
model M {
22+
id String @id @default(cuid())
23+
value Int
24+
@@allow('read', auth() != null)
25+
@@allow('create', auth().value > 0)
26+
27+
@@allow('update', auth() != null)
28+
@@deny('update', auth().value == null || auth().value <= 0)
29+
}
30+
`;
31+
32+
const { policy } = await loadSchema(model);
33+
34+
expect(policy.guard.m.read({ user: undefined })).toEqual(false);
35+
expect(policy.guard.m.read({ user: { id: '1' } })).toEqual(true);
36+
37+
expect(policy.guard.m.create({ user: undefined })).toEqual(false);
38+
expect(policy.guard.m.create({ user: { id: '1' } })).toEqual(false);
39+
expect(policy.guard.m.create({ user: { id: '1', value: 0 } })).toEqual(false);
40+
expect(policy.guard.m.create({ user: { id: '1', value: 1 } })).toEqual(true);
41+
42+
expect(policy.guard.m.update({ user: undefined })).toEqual(false);
43+
expect(policy.guard.m.update({ user: { id: '1' } })).toEqual(false);
44+
expect(policy.guard.m.update({ user: { id: '1', value: 0 } })).toEqual(false);
45+
expect(policy.guard.m.update({ user: { id: '1', value: 1 } })).toEqual(true);
46+
});
47+
48+
it('no short-circuit', async () => {
49+
const model = `
50+
model User {
51+
id String @id @default(cuid())
52+
value Int
53+
}
54+
55+
model M {
56+
id String @id @default(cuid())
57+
value Int
58+
@@allow('read', auth() != null && value > 0)
59+
}
60+
`;
61+
62+
const { policy } = await loadSchema(model);
63+
64+
expect(policy.guard.m.read({ user: undefined })).toEqual(
65+
expect.objectContaining({ AND: [{ zenstack_guard: false }, { value: { gt: 0 } }] })
66+
);
67+
expect(policy.guard.m.read({ user: { id: '1' } })).toEqual(
68+
expect.objectContaining({ AND: [{ zenstack_guard: true }, { value: { gt: 0 } }] })
69+
);
70+
});
71+
});

0 commit comments

Comments
 (0)