diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 082578ba2..1a572315f 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -65,7 +65,7 @@ export class PolicyProxyHandler implements Pr } args = this.utils.clone(args); - if (!(await this.utils.injectForRead(this.model, args))) { + if (!(await this.utils.injectForRead(this.prisma, this.model, args))) { return null; } @@ -86,7 +86,7 @@ export class PolicyProxyHandler implements Pr } args = this.utils.clone(args); - if (!(await this.utils.injectForRead(this.model, args))) { + if (!(await this.utils.injectForRead(this.prisma, this.model, args))) { throw this.utils.notFound(this.model); } @@ -100,7 +100,7 @@ export class PolicyProxyHandler implements Pr async findFirst(args: any) { args = args ? this.utils.clone(args) : {}; - if (!(await this.utils.injectForRead(this.model, args))) { + if (!(await this.utils.injectForRead(this.prisma, this.model, args))) { return null; } @@ -114,7 +114,7 @@ export class PolicyProxyHandler implements Pr async findFirstOrThrow(args: any) { args = args ? this.utils.clone(args) : {}; - if (!(await this.utils.injectForRead(this.model, args))) { + if (!(await this.utils.injectForRead(this.prisma, this.model, args))) { throw this.utils.notFound(this.model); } @@ -128,7 +128,7 @@ export class PolicyProxyHandler implements Pr async findMany(args: any) { args = args ? this.utils.clone(args) : {}; - if (!(await this.utils.injectForRead(this.model, args))) { + if (!(await this.utils.injectForRead(this.prisma, this.model, args))) { return []; } @@ -152,7 +152,7 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'data field is required in query argument'); } - await this.utils.tryReject(this.model, 'create'); + await this.utils.tryReject(this.prisma, this.model, 'create'); const origArgs = args; args = this.utils.clone(args); @@ -404,7 +404,7 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'data field is required in query argument'); } - this.utils.tryReject(this.model, 'create'); + this.utils.tryReject(this.prisma, this.model, 'create'); args = this.utils.clone(args); @@ -635,7 +635,7 @@ export class PolicyProxyHandler implements Pr } if (thisModelUpdate) { - this.utils.tryReject(this.model, 'update'); + this.utils.tryReject(db, this.model, 'update'); // check pre-update guard await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db); @@ -660,7 +660,7 @@ export class PolicyProxyHandler implements Pr updateMany: async (model, args, context) => { // injects auth guard into where clause - await this.utils.injectAuthGuard(args, model, 'update'); + await this.utils.injectAuthGuard(db, args, model, 'update'); // prepare for post-update check if (this.utils.hasAuthGuard(model, 'postUpdate') || this.utils.getZodSchema(model)) { @@ -671,7 +671,7 @@ export class PolicyProxyHandler implements Pr } const reversedQuery = await this.utils.buildReversedQuery(context); const currentSetQuery = { select, where: reversedQuery }; - await this.utils.injectAuthGuard(currentSetQuery, model, 'read'); + await this.utils.injectAuthGuard(db, currentSetQuery, model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`findMany\` ${model}:\n${formatObject(currentSetQuery)}`); @@ -794,7 +794,7 @@ export class PolicyProxyHandler implements Pr deleteMany: async (model, args, context) => { // inject delete guard - const guard = await this.utils.getAuthGuard(model, 'delete'); + const guard = await this.utils.getAuthGuard(db, model, 'delete'); context.parent.deleteMany = this.utils.and(args, guard); }, }); @@ -822,10 +822,10 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'data field is required in query argument'); } - await this.utils.tryReject(this.model, 'update'); + await this.utils.tryReject(this.prisma, this.model, 'update'); args = this.utils.clone(args); - await this.utils.injectAuthGuard(args, this.model, 'update'); + await this.utils.injectAuthGuard(this.prisma, args, this.model, 'update'); if (this.utils.hasAuthGuard(this.model, 'postUpdate') || this.utils.getZodSchema(this.model)) { // use a transaction to do post-update checks @@ -838,7 +838,7 @@ export class PolicyProxyHandler implements Pr select = { ...select, ...preValueSelect }; } const currentSetQuery = { select, where: args.where }; - await this.utils.injectAuthGuard(currentSetQuery, this.model, 'read'); + await this.utils.injectAuthGuard(tx, currentSetQuery, this.model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`); @@ -885,8 +885,8 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'update field is required in query argument'); } - await this.utils.tryReject(this.model, 'create'); - await this.utils.tryReject(this.model, 'update'); + await this.utils.tryReject(this.prisma, this.model, 'create'); + await this.utils.tryReject(this.prisma, this.model, 'update'); // We can call the native "upsert" because we can't tell if an entity was created or updated // for doing post-write check accordingly. Instead, decompose it into create or update. @@ -930,7 +930,7 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'where field is required in query argument'); } - await this.utils.tryReject(this.model, 'delete'); + await this.utils.tryReject(this.prisma, this.model, 'delete'); const { result, error } = await this.transaction(async (tx) => { // do a read-back before delete @@ -961,11 +961,11 @@ export class PolicyProxyHandler implements Pr } async deleteMany(args: any) { - await this.utils.tryReject(this.model, 'delete'); + await this.utils.tryReject(this.prisma, this.model, 'delete'); // inject policy conditions args = args ?? {}; - await this.utils.injectAuthGuard(args, this.model, 'delete'); + await this.utils.injectAuthGuard(this.prisma, args, this.model, 'delete'); // conduct the deletion if (this.shouldLogQuery) { @@ -984,7 +984,7 @@ export class PolicyProxyHandler implements Pr } // inject policy conditions - await this.utils.injectAuthGuard(args, this.model, 'read'); + await this.utils.injectAuthGuard(this.prisma, args, this.model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`aggregate\` ${this.model}:\n${formatObject(args)}`); @@ -998,7 +998,7 @@ export class PolicyProxyHandler implements Pr } // inject policy conditions - await this.utils.injectAuthGuard(args, this.model, 'read'); + await this.utils.injectAuthGuard(this.prisma, args, this.model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`groupBy\` ${this.model}:\n${formatObject(args)}`); @@ -1009,7 +1009,7 @@ export class PolicyProxyHandler implements Pr async count(args: any) { // inject policy conditions args = args ?? {}; - await this.utils.injectAuthGuard(args, this.model, 'read'); + await this.utils.injectAuthGuard(this.prisma, args, this.model, 'read'); if (this.shouldLogQuery) { this.logger.info(`[policy] \`count\` ${this.model}:\n${formatObject(args)}`); diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index de1fc1f1f..cedadb5cd 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -162,7 +162,12 @@ export class PolicyUtil { * @returns true if operation is unconditionally allowed, false if unconditionally denied, * otherwise returns a guard object */ - getAuthGuard(model: string, operation: PolicyOperationKind, preValue?: any): object { + getAuthGuard( + db: Record, + model: string, + operation: PolicyOperationKind, + preValue?: any + ): object { const guard = this.policy.guard[lowerCaseFirst(model)]; if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); @@ -176,7 +181,7 @@ export class PolicyUtil { if (!provider) { throw this.unknownError(`zenstack: unable to load authorization guard for ${model}`); } - const r = provider({ user: this.user, preValue }); + const r = provider({ user: this.user, preValue }, db); return this.reduce(r); } @@ -219,8 +224,8 @@ export class PolicyUtil { /** * Injects model auth guard as where clause. */ - async injectAuthGuard(args: any, model: string, operation: PolicyOperationKind) { - const guard = this.getAuthGuard(model, operation); + async injectAuthGuard(db: Record, args: any, model: string, operation: PolicyOperationKind) { + const guard = this.getAuthGuard(db, model, operation); if (this.isFalse(guard)) { args.where = this.makeFalse(); return false; @@ -230,14 +235,19 @@ export class PolicyUtil { // inject into relation fields: // to-many: some/none/every // to-one: direct-conditions/is/isNot - await this.injectGuardForRelationFields(model, args.where, operation); + await this.injectGuardForRelationFields(db, model, args.where, operation); } args.where = this.and(args.where, guard); return true; } - private async injectGuardForRelationFields(model: string, payload: any, operation: PolicyOperationKind) { + private async injectGuardForRelationFields( + db: Record, + model: string, + payload: any, + operation: PolicyOperationKind + ) { for (const [field, subPayload] of Object.entries(payload)) { if (!subPayload) { continue; @@ -249,26 +259,27 @@ export class PolicyUtil { } if (fieldInfo.isArray) { - await this.injectGuardForToManyField(fieldInfo, subPayload, operation); + await this.injectGuardForToManyField(db, fieldInfo, subPayload, operation); } else { - await this.injectGuardForToOneField(fieldInfo, subPayload, operation); + await this.injectGuardForToOneField(db, fieldInfo, subPayload, operation); } } } private async injectGuardForToManyField( + db: Record, fieldInfo: FieldInfo, payload: { some?: any; every?: any; none?: any }, operation: PolicyOperationKind ) { - const guard = this.getAuthGuard(fieldInfo.type, operation); + const guard = this.getAuthGuard(db, fieldInfo.type, operation); if (payload.some) { - await this.injectGuardForRelationFields(fieldInfo.type, payload.some, operation); + await this.injectGuardForRelationFields(db, fieldInfo.type, payload.some, operation); // turn "some" into: { some: { AND: [guard, payload.some] } } payload.some = this.and(payload.some, guard); } if (payload.none) { - await this.injectGuardForRelationFields(fieldInfo.type, payload.none, operation); + await this.injectGuardForRelationFields(db, fieldInfo.type, payload.none, operation); // turn none into: { none: { AND: [guard, payload.none] } } payload.none = this.and(payload.none, guard); } @@ -278,7 +289,7 @@ export class PolicyUtil { // ignore empty every clause Object.keys(payload.every).length > 0 ) { - await this.injectGuardForRelationFields(fieldInfo.type, payload.every, operation); + await this.injectGuardForRelationFields(db, fieldInfo.type, payload.every, operation); // turn "every" into: { none: { AND: [guard, { NOT: payload.every }] } } if (!payload.none) { @@ -290,25 +301,26 @@ export class PolicyUtil { } private async injectGuardForToOneField( + db: Record, fieldInfo: FieldInfo, payload: { is?: any; isNot?: any } & Record, operation: PolicyOperationKind ) { - const guard = this.getAuthGuard(fieldInfo.type, operation); + const guard = this.getAuthGuard(db, fieldInfo.type, operation); if (payload.is || payload.isNot) { if (payload.is) { - await this.injectGuardForRelationFields(fieldInfo.type, payload.is, operation); + await this.injectGuardForRelationFields(db, fieldInfo.type, payload.is, operation); // turn "is" into: { is: { AND: [ originalIs, guard ] } payload.is = this.and(payload.is, guard); } if (payload.isNot) { - await this.injectGuardForRelationFields(fieldInfo.type, payload.isNot, operation); + await this.injectGuardForRelationFields(db, fieldInfo.type, payload.isNot, operation); // turn "isNot" into: { isNot: { AND: [ originalIsNot, { NOT: guard } ] } } payload.isNot = this.and(payload.isNot, this.not(guard)); delete payload.isNot; } } else { - await this.injectGuardForRelationFields(fieldInfo.type, payload, operation); + await this.injectGuardForRelationFields(db, fieldInfo.type, payload, operation); // turn direct conditions into: { is: { AND: [ originalConditions, guard ] } } const combined = this.and(deepcopy(payload), guard); Object.keys(payload).forEach((key) => delete payload[key]); @@ -319,9 +331,9 @@ export class PolicyUtil { /** * Injects auth guard for read operations. */ - async injectForRead(model: string, args: any) { + async injectForRead(db: Record, model: string, args: any) { const injected: any = {}; - if (!(await this.injectAuthGuard(injected, model, 'read'))) { + if (!(await this.injectAuthGuard(db, injected, model, 'read'))) { return false; } @@ -329,7 +341,7 @@ export class PolicyUtil { // inject into relation fields: // to-many: some/none/every // to-one: direct-conditions/is/isNot - await this.injectGuardForRelationFields(model, args.where, 'read'); + await this.injectGuardForRelationFields(db, model, args.where, 'read'); } if (injected.where && Object.keys(injected.where).length > 0 && !this.isTrue(injected.where)) { @@ -338,7 +350,7 @@ export class PolicyUtil { } // recursively inject read guard conditions into nested select, include, and _count - const hoistedConditions = await this.injectNestedReadConditions(model, args); + const hoistedConditions = await this.injectNestedReadConditions(db, model, args); // the injection process may generate conditions that need to be hoisted to the toplevel, // if so, merge it with the existing where @@ -429,7 +441,11 @@ export class PolicyUtil { return result; } - private async injectNestedReadConditions(model: string, args: any): Promise { + private async injectNestedReadConditions( + db: Record, + model: string, + args: any + ): Promise { const injectTarget = args.select ?? args.include; if (!injectTarget) { return []; @@ -462,7 +478,7 @@ export class PolicyUtil { continue; } // inject into the "where" clause inside select - await this.injectAuthGuard(injectTarget._count.select[field], fieldInfo.type, 'read'); + await this.injectAuthGuard(db, injectTarget._count.select[field], fieldInfo.type, 'read'); } } @@ -488,19 +504,19 @@ export class PolicyUtil { injectTarget[field] = {}; } // inject extra condition for to-many or nullable to-one relation - await this.injectAuthGuard(injectTarget[field], fieldInfo.type, 'read'); + await this.injectAuthGuard(db, injectTarget[field], fieldInfo.type, 'read'); // recurse - const subHoisted = await this.injectNestedReadConditions(fieldInfo.type, injectTarget[field]); + const subHoisted = await this.injectNestedReadConditions(db, fieldInfo.type, injectTarget[field]); if (subHoisted.length > 0) { // we can convert it to a where at this level injectTarget[field].where = this.and(injectTarget[field].where, ...subHoisted); } } else { // hoist non-nullable to-one filter to the parent level - hoisted = this.getAuthGuard(fieldInfo.type, 'read'); + hoisted = this.getAuthGuard(db, fieldInfo.type, 'read'); // recurse - const subHoisted = await this.injectNestedReadConditions(fieldInfo.type, injectTarget[field]); + const subHoisted = await this.injectNestedReadConditions(db, fieldInfo.type, injectTarget[field]); if (subHoisted.length > 0) { hoisted = this.and(hoisted, ...subHoisted); } @@ -525,7 +541,7 @@ export class PolicyUtil { db: Record, preValue?: any ) { - const guard = this.getAuthGuard(model, operation, preValue); + const guard = this.getAuthGuard(db, model, operation, preValue); if (this.isFalse(guard)) { throw this.deniedByPolicy(model, operation, `entity ${formatObject(uniqueFilter)} failed policy check`); } @@ -581,8 +597,8 @@ export class PolicyUtil { /** * Tries rejecting a request based on static "false" policy. */ - tryReject(model: string, operation: PolicyOperationKind) { - const guard = this.getAuthGuard(model, operation); + tryReject(db: Record, model: string, operation: PolicyOperationKind) { + const guard = this.getAuthGuard(db, model, operation); if (this.isFalse(guard)) { throw this.deniedByPolicy(model, operation); } @@ -633,7 +649,7 @@ export class PolicyUtil { CrudFailureReason.RESULT_NOT_READABLE ); - const injectResult = await this.injectForRead(model, readArgs); + const injectResult = await this.injectForRead(db, model, readArgs); if (!injectResult) { return { error, result: undefined }; } diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index 3de460126..72ea092d0 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -1,6 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { z } from 'zod'; -import { FieldInfo, PolicyOperationKind, QueryContext } from '../types'; +import type { DbOperations, FieldInfo, PolicyOperationKind, QueryContext } from '../types'; /** * Metadata for a model-level unique constraint @@ -19,7 +19,7 @@ export type ModelMeta = { /** * Function for getting policy guard with a given context */ -export type PolicyFunc = (context: QueryContext) => object; +export type PolicyFunc = (context: QueryContext, db: Record) => object; /** * Function for getting policy guard with a given context diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index ced5bc699..b796fd0b2 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -1,3 +1,5 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ + /** * Weakly-typed database access methods */ @@ -17,6 +19,7 @@ export interface DbOperations { aggregate(args: unknown): Promise; groupBy(args: unknown): Promise; count(args?: unknown): Promise; + fields: Record; } /** @@ -43,6 +46,9 @@ export type QueryContext = { */ user?: AuthUser; + /** + * Pre-update value of the entity + */ // eslint-disable-next-line @typescript-eslint/no-explicit-any preValue?: any; }; diff --git a/packages/schema/src/plugins/access-policy/expression-writer.ts b/packages/schema/src/plugins/access-policy/expression-writer.ts index 3c7fbdd1d..d27b9ee55 100644 --- a/packages/schema/src/plugins/access-policy/expression-writer.ts +++ b/packages/schema/src/plugins/access-policy/expression-writer.ts @@ -4,7 +4,6 @@ import { Expression, InvocationExpr, isDataModel, - isDataModelField, isEnumField, isMemberAccessExpr, isReferenceExpr, @@ -14,7 +13,14 @@ import { ReferenceExpr, UnaryExpr, } from '@zenstackhq/language/ast'; -import { ExpressionContext, getFunctionExpressionContext, getLiteral, PluginError } from '@zenstackhq/sdk'; +import { + ExpressionContext, + getFunctionExpressionContext, + getLiteral, + isDataModelFieldReference, + PluginError, +} from '@zenstackhq/sdk'; +import { lowerCaseFirst } from 'lower-case-first'; import { CodeBlockWriter } from 'ts-morph'; import { name } from '.'; import { getIdFields, isAuthInvocation } from '../../utils/ast-utils'; @@ -191,6 +197,19 @@ export class ExpressionWriter { }, 'has' ); + } else if ( + isDataModelFieldReference(expr.left) && + isDataModelFieldReference(expr.right) && + expr.left.target.ref?.$container === expr.right.target.ref?.$container + ) { + // comparing two fields of the same model + this.writeFieldCondition( + expr.left, + () => { + this.writeFieldReference(expr.right as ReferenceExpr); + }, + 'in' + ); } else { throw new PluginError(name, '"in" operator cannot be used with field references on both sides'); } @@ -223,7 +242,7 @@ export class ExpressionWriter { return this.isFieldAccess(expr.operand); } } - if (isReferenceExpr(expr) && isDataModelField(expr.target.ref) && !this.isPostGuard) { + if (isDataModelFieldReference(expr) && !this.isPostGuard) { return true; } return false; @@ -256,7 +275,15 @@ export class ExpressionWriter { const rightIsFieldAccess = this.isFieldAccess(expr.right); if (leftIsFieldAccess && rightIsFieldAccess) { - throw new PluginError(name, `Comparison between fields are not supported yet`); + if ( + isDataModelFieldReference(expr.left) && + isDataModelFieldReference(expr.right) && + expr.left.target.ref?.$container === expr.right.target.ref?.$container + ) { + // comparing fields from the same model + } else { + throw new PluginError(name, `Comparing fields from different models is not supported`); + } } if (!leftIsFieldAccess && !rightIsFieldAccess) { @@ -358,7 +385,13 @@ export class ExpressionWriter { }); } else { this.writeOperator(operator, fieldAccess, () => { - this.plain(operand); + if (isDataModelFieldReference(operand) && !this.isPostGuard) { + // if operand is a field reference and we're not generating for post-update guard, + // we should generate a field reference (comparing fields in the same model) + this.writeFieldReference(operand); + } else { + this.plain(operand); + } }); } }, !isThisExpr(fieldAccess)); @@ -370,6 +403,15 @@ export class ExpressionWriter { ); } + // https://www.prisma.io/docs/reference/api-reference/prisma-client-reference#compare-columns-in-the-same-table + private writeFieldReference(expr: ReferenceExpr) { + if (!expr.target.ref) { + throw new PluginError(name, `Unresolved reference "${expr.target.$refText}"`); + } + const containingModel = expr.target.ref.$container; + this.writer.write(`db.${lowerCaseFirst(containingModel.name)}.fields.${expr.target.ref.name}`); + } + private isAuthOrAuthMemberAccess(expr: Expression) { return isAuthInvocation(expr) || (isMemberAccessExpr(expr) && isAuthInvocation(expr.operand)); } diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index 54964edc1..bacdc688c 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -70,7 +70,7 @@ export default class PolicyGenerator { sf.addStatements('/* eslint-disable */'); sf.addImportDeclaration({ - namedImports: [{ name: 'type QueryContext' }, { name: 'hasAllFields' }], + namedImports: [{ name: 'type QueryContext' }, { name: 'type DbOperations' }, { name: 'hasAllFields' }], moduleSpecifier: `${RUNTIME_PACKAGE}`, }); @@ -487,6 +487,10 @@ export default class PolicyGenerator { name: 'context', type: 'QueryContext', }, + { + name: 'db', + type: 'Record', + }, ], statements, }); diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index b6a6f80bc..46db336ba 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -274,6 +274,13 @@ export default class PrismaSchemaGenerator { } } + if (semver.lt(prismaVersion, '5.0.0')) { + // fieldReference feature is opt-in pre V5 + if (!previewFeatures.includes('fieldReference')) { + previewFeatures.push('fieldReference'); + } + } + if (previewFeatures.length > 0) { const curr = generator.fields.find((f) => f.name === 'previewFeatures'); if (!curr) { diff --git a/packages/sdk/src/utils.ts b/packages/sdk/src/utils.ts index fb2e7e6cf..9543b3242 100644 --- a/packages/sdk/src/utils.ts +++ b/packages/sdk/src/utils.ts @@ -11,6 +11,7 @@ import { InternalAttribute, isArrayExpr, isDataModel, + isDataModelField, isEnumField, isLiteralExpr, isObjectExpr, @@ -20,8 +21,8 @@ import { ReferenceExpr, } from '@zenstackhq/language/ast'; import path from 'path'; -import { PluginOptions } from './types'; import { ExpressionContext } from './constants'; +import { PluginOptions } from './types'; /** * Gets data models that are not ignored @@ -137,6 +138,10 @@ export function isEnumFieldReference(node: AstNode): node is ReferenceExpr { return isReferenceExpr(node) && isEnumField(node.target.ref); } +export function isDataModelFieldReference(node: AstNode): node is ReferenceExpr { + return isReferenceExpr(node) && isDataModelField(node.target.ref); +} + /** * Gets `@@id` fields declared at the data model level */ diff --git a/tests/integration/tests/enhancements/with-policy/field-comparison.test.ts b/tests/integration/tests/enhancements/with-policy/field-comparison.test.ts new file mode 100644 index 000000000..1fbb0bca4 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/field-comparison.test.ts @@ -0,0 +1,166 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import path from 'path'; +import { Pool } from 'pg'; + +const DB_NAME = 'field-comparison'; + +describe('WithPolicy: field comparison tests', () => { + let origDir: string; + let prisma: any; + + const pool = new Pool({ user: 'postgres', password: 'abc123' }); + + beforeAll(async () => { + origDir = path.resolve('.'); + }); + + beforeEach(async () => { + await pool.query(`DROP DATABASE IF EXISTS "${DB_NAME}";`); + await pool.query(`CREATE DATABASE "${DB_NAME}";`); + }); + + afterEach(async () => { + process.chdir(origDir); + if (prisma) { + await prisma.$disconnect(); + } + await pool.query(`DROP DATABASE IF EXISTS "${DB_NAME}";`); + }); + + it('field comparison success with input check', async () => { + const r = await loadSchema( + ` + datasource db { + provider = 'postgresql' + url = 'postgres://postgres:abc123@localhost:5432/${DB_NAME}' + } + + generator js { + provider = 'prisma-client-js' + } + + model Model { + id String @id @default(uuid()) + x Int + y Int + + @@allow('create', x > y) + @@allow('read', true) + } + `, + { addPrelude: false } + ); + + prisma = r.prisma; + const db = r.withPolicy(); + await expect(db.model.create({ data: { x: 1, y: 2 } })).toBeRejectedByPolicy(); + await expect(db.model.create({ data: { x: 2, y: 1 } })).toResolveTruthy(); + }); + + it('field comparison success with policy check', async () => { + const r = await loadSchema( + ` + datasource db { + provider = 'postgresql' + url = 'postgres://postgres:abc123@localhost:5432/${DB_NAME}' + } + + generator js { + provider = 'prisma-client-js' + } + + model Model { + id String @id @default(uuid()) + x Int @default(0) + y Int @default(0) + + @@allow('create', x > y) + @@allow('read', true) + } + `, + { addPrelude: false } + ); + + prisma = r.prisma; + const db = r.withPolicy(); + await expect(db.model.create({ data: { x: 1, y: 2 } })).toBeRejectedByPolicy(); + await expect(db.model.create({ data: { x: 2, y: 1 } })).toResolveTruthy(); + }); + + it('field in operator success with input check', async () => { + const r = await loadSchema( + ` + datasource db { + provider = 'postgresql' + url = 'postgres://postgres:abc123@localhost:5432/${DB_NAME}' + } + + generator js { + provider = 'prisma-client-js' + } + + model Model { + id String @id @default(uuid()) + x String + y String[] + + @@allow('create', x in y) + @@allow('read', x in y) + } + `, + { addPrelude: false } + ); + + prisma = r.prisma; + const db = r.withPolicy(); + await expect(db.model.create({ data: { x: 'a', y: ['b', 'c'] } })).toBeRejectedByPolicy(); + await expect(db.model.create({ data: { x: 'a', y: ['a', 'c'] } })).toResolveTruthy(); + }); + + it('field in operator success with policy check', async () => { + const r = await loadSchema( + ` + datasource db { + provider = 'postgresql' + url = 'postgres://postgres:abc123@localhost:5432/${DB_NAME}' + } + + generator js { + provider = 'prisma-client-js' + } + + model Model { + id String @id @default(uuid()) + x String @default('x') + y String[] + + @@allow('create', x in y) + @@allow('read', x in y) + } + `, + { addPrelude: false } + ); + + prisma = r.prisma; + const db = r.withPolicy(); + await expect(db.model.create({ data: { x: 'a', y: ['b', 'c'] } })).toBeRejectedByPolicy(); + await expect(db.model.create({ data: { x: 'a', y: ['a', 'c'] } })).toResolveTruthy(); + }); + + it('field comparison type error', async () => { + await expect( + loadSchema( + ` + model Model { + id String @id @default(uuid()) + x Int + y String + + @@allow('create', x > y) + @@allow('read', true) + } + ` + ) + ).rejects.toThrow(/invalid operand type/); + }); +});