diff --git a/package.json b/package.json index 3f8f59a91..64757ab2d 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "zenstack-monorepo", - "version": "1.0.0-alpha.112", + "version": "1.0.0-alpha.113", "description": "", "scripts": { "build": "pnpm -r build", diff --git a/packages/language/package.json b/packages/language/package.json index cb38faa65..37d6ff34b 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/language", - "version": "1.0.0-alpha.112", + "version": "1.0.0-alpha.113", "displayName": "ZenStack modeling language compiler", "description": "ZenStack modeling language compiler", "homepage": "https://zenstack.dev", diff --git a/packages/next/package.json b/packages/next/package.json index 45467ec7d..aa47b06dd 100644 --- a/packages/next/package.json +++ b/packages/next/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/next", - "version": "1.0.0-alpha.112", + "version": "1.0.0-alpha.113", "displayName": "ZenStack Next.js integration", "description": "ZenStack Next.js integration", "homepage": "https://zenstack.dev", diff --git a/packages/plugins/openapi/package.json b/packages/plugins/openapi/package.json index f796dd14f..3744c94fb 100644 --- a/packages/plugins/openapi/package.json +++ b/packages/plugins/openapi/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/openapi", "displayName": "ZenStack Plugin and Runtime for OpenAPI", - "version": "1.0.0-alpha.112", + "version": "1.0.0-alpha.113", "description": "ZenStack plugin and runtime supporting OpenAPI", "main": "index.js", "repository": { diff --git a/packages/plugins/react/package.json b/packages/plugins/react/package.json index 4c8e8f80a..21cd12088 100644 --- a/packages/plugins/react/package.json +++ b/packages/plugins/react/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/react", "displayName": "ZenStack plugin and runtime for ReactJS", - "version": "1.0.0-alpha.112", + "version": "1.0.0-alpha.113", "description": "ZenStack plugin and runtime for ReactJS", "main": "index.js", "repository": { diff --git a/packages/plugins/trpc/package.json b/packages/plugins/trpc/package.json index 719c8c83b..45905e994 100644 --- a/packages/plugins/trpc/package.json +++ b/packages/plugins/trpc/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/trpc", "displayName": "ZenStack plugin for tRPC", - "version": "1.0.0-alpha.112", + "version": "1.0.0-alpha.113", "description": "ZenStack plugin for tRPC", "main": "index.js", "repository": { diff --git a/packages/runtime/package.json b/packages/runtime/package.json index a902a5654..1b30261c4 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/runtime", "displayName": "ZenStack Runtime Library", - "version": "1.0.0-alpha.112", + "version": "1.0.0-alpha.113", "description": "Runtime of ZenStack for both client-side and server-side environments.", "repository": { "type": "git", diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 4f244bef5..839ad15b9 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -321,6 +321,10 @@ export class PolicyUtil { * omitted. */ async postProcessForRead(entityData: any, model: string, args: any, operation: PolicyOperationKind) { + if (typeof entityData !== 'object' || !entityData) { + return; + } + const ids = this.getEntityIds(model, entityData); if (Object.keys(ids).length === 0) { return; @@ -739,6 +743,14 @@ export class PolicyUtil { operation: PolicyOperationKind, db: Record ) { + const guard = await this.getAuthGuard(model, operation); + const schema = (operation === 'create' || operation === 'update') && (await this.getModelSchema(model)); + + if (guard === true && !schema) { + // unconditionally allowed + return; + } + // DEBUG // this.logger.info(`Checking policy for ${model}#${JSON.stringify(filter)} for ${operation}`); @@ -750,13 +762,19 @@ export class PolicyUtil { await this.flattenGeneratedUniqueField(model, queryFilter); const count = (await db[model].count({ where: queryFilter })) as number; - const guard = await this.getAuthGuard(model, operation); + if (count === 0) { + // there's nothing to filter out + return; + } + + if (guard === false) { + // unconditionally denied + throw this.deniedByPolicy(model, operation, `${count} ${pluralize('entity', count)} failed policy check`); + } // build a query condition with policy injected const guardedQuery = { where: this.and(queryFilter, guard) }; - const schema = (operation === 'create' || operation === 'update') && (await this.getModelSchema(model)); - if (schema) { // we've got schemas, so have to fetch entities and validate them const entities = await db[model].findMany(guardedQuery); diff --git a/packages/schema/package.json b/packages/schema/package.json index b4b3806d9..b46808e18 100644 --- a/packages/schema/package.json +++ b/packages/schema/package.json @@ -3,7 +3,7 @@ "publisher": "zenstack", "displayName": "ZenStack Language Tools", "description": "A toolkit for building secure CRUD apps with Next.js + Typescript", - "version": "1.0.0-alpha.112", + "version": "1.0.0-alpha.113", "author": { "name": "ZenStack Team" }, diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index 31274dd13..ac9e26e92 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -9,6 +9,7 @@ import { isInvocationExpr, isMemberAccessExpr, isReferenceExpr, + isThisExpr, isUnaryExpr, MemberAccessExpr, Model, @@ -33,9 +34,10 @@ import path from 'path'; import { FunctionDeclaration, SourceFile, VariableDeclarationKind } from 'ts-morph'; import { name } from '.'; import { isFromStdlib } from '../../language-server/utils'; -import { getIdFields } from '../../utils/ast-utils'; +import { getIdFields, isAuthInvocation } from '../../utils/ast-utils'; import { ALL_OPERATION_KINDS, getDefaultOutputFolder } from '../plugin-utils'; import { ExpressionWriter } from './expression-writer'; +import TypeScriptExpressionTransformer from './typescript-expression-transformer'; import { isFutureExpr } from './utils'; import { ZodSchemaGenerator } from './zod-schema-generator'; @@ -332,18 +334,9 @@ export default class PolicyGenerator { .addBody(); // check if any allow or deny rule contains 'auth()' invocation - let hasAuthRef = false; - for (const node of [...denies, ...allows]) { - for (const child of streamAllContents(node)) { - if (isInvocationExpr(child) && resolved(child.function).name === 'auth') { - hasAuthRef = true; - break; - } - } - if (hasAuthRef) { - break; - } - } + const hasAuthRef = [...denies, ...allows].some((rule) => + streamAllContents(rule).some((child) => isAuthInvocation(child)) + ); if (hasAuthRef) { const userModel = model.$container.declarations.find( @@ -365,47 +358,73 @@ export default class PolicyGenerator { ); } - // r = ; - func.addStatements((writer) => { - writer.write('return '); - const exprWriter = new ExpressionWriter(writer, kind === 'postUpdate'); - const writeDenies = () => { - writer.conditionalWrite(denies.length > 1, '{ AND: ['); - denies.forEach((expr, i) => { - writer.inlineBlock(() => { - writer.write('NOT: '); - exprWriter.write(expr); - }); - writer.conditionalWrite(i !== denies.length - 1, ','); + const hasFieldAccess = [...denies, ...allows].some((rule) => + streamAllContents(rule).some( + (child) => + // this.??? + isThisExpr(child) || + // future().??? + isFutureExpr(child) || + // field reference + (isReferenceExpr(child) && isDataModelField(child.target.ref)) + ) + ); + + if (!hasFieldAccess) { + // none of the rules reference model fields, we can compile down to a plain boolean + // function in this case (so we can skip doing SQL queries when validating) + func.addStatements((writer) => { + const transformer = new TypeScriptExpressionTransformer(kind === 'postUpdate'); + denies.forEach((rule) => { + writer.write(`if (${transformer.transform(rule, false)}) { return false; }`); }); - writer.conditionalWrite(denies.length > 1, ']}'); - }; - - const writeAllows = () => { - writer.conditionalWrite(allows.length > 1, '{ OR: ['); - allows.forEach((expr, i) => { - exprWriter.write(expr); - writer.conditionalWrite(i !== allows.length - 1, ','); + allows.forEach((rule) => { + writer.write(`if (${transformer.transform(rule, false)}) { return true; }`); }); - writer.conditionalWrite(allows.length > 1, ']}'); - }; - - if (allows.length > 0 && denies.length > 0) { - writer.write('{ AND: ['); - writeDenies(); - writer.write(','); - writeAllows(); - writer.write(']}'); - } else if (denies.length > 0) { - writeDenies(); - } else if (allows.length > 0) { - writeAllows(); - } else { - // disallow any operation - writer.write(`{ ${GUARD_FIELD_NAME}: false }`); - } - writer.write(';'); - }); + writer.write('return false;'); + }); + } else { + func.addStatements((writer) => { + writer.write('return '); + const exprWriter = new ExpressionWriter(writer, kind === 'postUpdate'); + const writeDenies = () => { + writer.conditionalWrite(denies.length > 1, '{ AND: ['); + denies.forEach((expr, i) => { + writer.inlineBlock(() => { + writer.write('NOT: '); + exprWriter.write(expr); + }); + writer.conditionalWrite(i !== denies.length - 1, ','); + }); + writer.conditionalWrite(denies.length > 1, ']}'); + }; + + const writeAllows = () => { + writer.conditionalWrite(allows.length > 1, '{ OR: ['); + allows.forEach((expr, i) => { + exprWriter.write(expr); + writer.conditionalWrite(i !== allows.length - 1, ','); + }); + writer.conditionalWrite(allows.length > 1, ']}'); + }; + + if (allows.length > 0 && denies.length > 0) { + writer.write('{ AND: ['); + writeDenies(); + writer.write(','); + writeAllows(); + writer.write(']}'); + } else if (denies.length > 0) { + writeDenies(); + } else if (allows.length > 0) { + writeAllows(); + } else { + // disallow any operation + writer.write(`{ ${GUARD_FIELD_NAME}: false }`); + } + writer.write(';'); + }); + } return func; } } diff --git a/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts b/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts index 98dde9004..61d4b6f9c 100644 --- a/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts +++ b/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts @@ -202,7 +202,10 @@ export default class TypeScriptExpressionTransformer { normalizeUndefined )}) ?? false)`; } else { - return `(${this.transform(expr.left)} ${expr.operator} ${this.transform(expr.right, normalizeUndefined)})`; + return `(${this.transform(expr.left, normalizeUndefined)} ${expr.operator} ${this.transform( + expr.right, + normalizeUndefined + )})`; } } } diff --git a/packages/schema/src/plugins/access-policy/utils.ts b/packages/schema/src/plugins/access-policy/utils.ts index 741d686d9..816386a6e 100644 --- a/packages/schema/src/plugins/access-policy/utils.ts +++ b/packages/schema/src/plugins/access-policy/utils.ts @@ -1,9 +1,10 @@ -import { Expression, isInvocationExpr } from '@zenstackhq/language/ast'; +import { isInvocationExpr } from '@zenstackhq/language/ast'; +import { AstNode } from 'langium/lib/syntax-tree'; import { isFromStdlib } from '../../language-server/utils'; /** * Returns if the given expression is a "future()" method call. */ -export function isFutureExpr(expr: Expression) { - return !!(isInvocationExpr(expr) && expr.function.ref?.name === 'future' && isFromStdlib(expr.function.ref)); +export function isFutureExpr(node: AstNode) { + return !!(isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)); } diff --git a/packages/schema/src/utils/ast-utils.ts b/packages/schema/src/utils/ast-utils.ts index ba951a1ea..ab26509eb 100644 --- a/packages/schema/src/utils/ast-utils.ts +++ b/packages/schema/src/utils/ast-utils.ts @@ -17,10 +17,9 @@ import { } from '@zenstackhq/language/ast'; import { PolicyOperationKind } from '@zenstackhq/runtime'; import { getLiteral } from '@zenstackhq/sdk'; -import { AstNode, Mutable } from 'langium'; -import { isFromStdlib } from '../language-server/utils'; -import { getDocument, LangiumDocuments } from 'langium'; +import { AstNode, getDocument, LangiumDocuments, Mutable } from 'langium'; import { URI, Utils } from 'vscode-uri'; +import { isFromStdlib } from '../language-server/utils'; export function extractDataModelsWithAllowRules(model: Model): DataModel[] { return model.declarations.filter( @@ -163,8 +162,8 @@ export function getIdFields(dataModel: DataModel) { return []; } -export function isAuthInvocation(expr: Expression) { - return isInvocationExpr(expr) && expr.function.ref?.name === 'auth' && isFromStdlib(expr.function.ref); +export function isAuthInvocation(node: AstNode) { + return isInvocationExpr(node) && node.function.ref?.name === 'auth' && isFromStdlib(node.function.ref); } export function isEnumFieldReference(expr: Expression) { diff --git a/packages/schema/tests/plugins/policy.test.ts b/packages/schema/tests/plugins/policy.test.ts new file mode 100644 index 000000000..6666abb5e --- /dev/null +++ b/packages/schema/tests/plugins/policy.test.ts @@ -0,0 +1,71 @@ +import { loadSchema } from '@zenstackhq/testtools'; + +describe('Policy plugin tests', () => { + let origDir: string; + + beforeEach(() => { + origDir = process.cwd(); + }); + + afterEach(() => { + process.chdir(origDir); + }); + + it('short-circuit', async () => { + const model = ` +model User { + id String @id @default(cuid()) + value Int +} + +model M { + id String @id @default(cuid()) + value Int + @@allow('read', auth() != null) + @@allow('create', auth().value > 0) + + @@allow('update', auth() != null) + @@deny('update', auth().value == null || auth().value <= 0) +} + `; + + const { policy } = await loadSchema(model); + + expect(policy.guard.m.read({ user: undefined })).toEqual(false); + expect(policy.guard.m.read({ user: { id: '1' } })).toEqual(true); + + expect(policy.guard.m.create({ user: undefined })).toEqual(false); + expect(policy.guard.m.create({ user: { id: '1' } })).toEqual(false); + expect(policy.guard.m.create({ user: { id: '1', value: 0 } })).toEqual(false); + expect(policy.guard.m.create({ user: { id: '1', value: 1 } })).toEqual(true); + + expect(policy.guard.m.update({ user: undefined })).toEqual(false); + expect(policy.guard.m.update({ user: { id: '1' } })).toEqual(false); + expect(policy.guard.m.update({ user: { id: '1', value: 0 } })).toEqual(false); + expect(policy.guard.m.update({ user: { id: '1', value: 1 } })).toEqual(true); + }); + + it('no short-circuit', async () => { + const model = ` +model User { + id String @id @default(cuid()) + value Int +} + +model M { + id String @id @default(cuid()) + value Int + @@allow('read', auth() != null && value > 0) +} + `; + + const { policy } = await loadSchema(model); + + expect(policy.guard.m.read({ user: undefined })).toEqual( + expect.objectContaining({ AND: [{ zenstack_guard: false }, { value: { gt: 0 } }] }) + ); + expect(policy.guard.m.read({ user: { id: '1' } })).toEqual( + expect.objectContaining({ AND: [{ zenstack_guard: true }, { value: { gt: 0 } }] }) + ); + }); +}); diff --git a/packages/sdk/package.json b/packages/sdk/package.json index 1fe8e20b4..270d813e3 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/sdk", - "version": "1.0.0-alpha.112", + "version": "1.0.0-alpha.113", "description": "ZenStack plugin development SDK", "main": "index.js", "scripts": { diff --git a/packages/server/package.json b/packages/server/package.json index 1914bd143..a1a9b09bc 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/server", - "version": "1.0.0-alpha.112", + "version": "1.0.0-alpha.113", "displayName": "ZenStack Server-side Adapters", "description": "ZenStack server-side adapters", "homepage": "https://zenstack.dev", diff --git a/packages/testtools/package.json b/packages/testtools/package.json index b92305198..fefe895ab 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/testtools", - "version": "1.0.0-alpha.112", + "version": "1.0.0-alpha.113", "description": "ZenStack Test Tools", "main": "index.js", "publishConfig": { diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index 82f91b06d..2f278eabb 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -150,6 +150,8 @@ export async function loadSchema( withOmit: () => withOmit(prisma, modelMeta), withPassword: () => withPassword(prisma, modelMeta), withPresets: (user?: AuthUser) => withPresets(prisma, { user }, policy, modelMeta), + policy, + modelMeta, zodSchemas, }; } catch (err) { diff --git a/tests/integration/test-run/package-lock.json b/tests/integration/test-run/package-lock.json index d93b7a941..9e463e65e 100644 --- a/tests/integration/test-run/package-lock.json +++ b/tests/integration/test-run/package-lock.json @@ -126,7 +126,7 @@ }, "../../../packages/runtime/dist": { "name": "@zenstackhq/runtime", - "version": "1.0.0-alpha.112", + "version": "1.0.0-alpha.113", "license": "MIT", "dependencies": { "@paralleldrive/cuid2": "^2.2.0", @@ -158,7 +158,7 @@ }, "../../../packages/schema/dist": { "name": "zenstack", - "version": "1.0.0-alpha.112", + "version": "1.0.0-alpha.113", "hasInstallScript": true, "license": "MIT", "dependencies": {