From ff9951aacc412db740c1648887ca06668a34d6b3 Mon Sep 17 00:00:00 2001 From: Yiming Date: Wed, 26 Jul 2023 11:37:05 +0800 Subject: [PATCH 1/5] refactor: avoid post-read checking for non-nullable to-one relation (#607) --- packages/runtime/src/constants.ts | 5 + .../src/enhancements/policy/handler.ts | 20 +-- .../runtime/src/enhancements/policy/index.ts | 13 ++ .../src/enhancements/policy/policy-utils.ts | 165 ++++++------------ .../src/plugins/prisma/schema-generator.ts | 36 ++-- .../with-policy/deep-nested.test.ts | 78 ++++++++- .../with-policy/nested-to-one.test.ts | 4 +- .../enhancements/with-policy/view.test.ts | 2 +- 8 files changed, 185 insertions(+), 138 deletions(-) diff --git a/packages/runtime/src/constants.ts b/packages/runtime/src/constants.ts index 459304dd4..5297bcb59 100644 --- a/packages/runtime/src/constants.ts +++ b/packages/runtime/src/constants.ts @@ -49,3 +49,8 @@ export const PRISIMA_TX_FLAG = '$__zenstack_tx'; * Field name for getting current enhancer */ export const PRISMA_PROXY_ENHANCER = '$__zenstack_enhancer'; + +/** + * Minimum Prisma version supported + */ +export const PRISMA_MINIMUM_VERSION = '4.8.0'; diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 0883cdd45..c22dd01c9 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -118,7 +118,7 @@ export class PolicyProxyHandler implements Pr // entity fails access policies const result: any = await this.utils.processWrite(this.model, 'create', args, (dbOps, writeArgs) => { if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`create\`: ${formatObject(writeArgs)}`); + this.logger.info(`[withPolicy] \`create\` ${this.model}: ${formatObject(writeArgs)}`); } return dbOps.create(writeArgs); }); @@ -147,7 +147,7 @@ export class PolicyProxyHandler implements Pr // entity fails access policies const result = await this.utils.processWrite(this.model, 'create', args, (dbOps, writeArgs) => { if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`createMany\`: ${formatObject(writeArgs)}`); + this.logger.info(`[withPolicy] \`createMany\` ${this.model}: ${formatObject(writeArgs)}`); } return dbOps.createMany(writeArgs, skipDuplicates); }); @@ -175,7 +175,7 @@ export class PolicyProxyHandler implements Pr // create fails access policies const result: any = await this.utils.processWrite(this.model, 'update', args, (dbOps, writeArgs) => { if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`update\`: ${formatObject(writeArgs)}`); + this.logger.info(`[withPolicy] \`update\` ${this.model}: ${formatObject(writeArgs)}`); } return dbOps.update(writeArgs); }); @@ -203,7 +203,7 @@ export class PolicyProxyHandler implements Pr // create fails access policies const result = await this.utils.processWrite(this.model, 'updateMany', args, (dbOps, writeArgs) => { if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`updateMany\`: ${formatObject(writeArgs)}`); + this.logger.info(`[withPolicy] \`updateMany\` ${this.model}: ${formatObject(writeArgs)}`); } return dbOps.updateMany(writeArgs); }); @@ -235,7 +235,7 @@ export class PolicyProxyHandler implements Pr // create fails access policies const result: any = await this.utils.processWrite(this.model, 'upsert', args, (dbOps, writeArgs) => { if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`upsert\`: ${formatObject(writeArgs)}`); + this.logger.info(`[withPolicy] \`upsert\` ${this.model}: ${formatObject(writeArgs)}`); } return dbOps.upsert(writeArgs); }); @@ -273,7 +273,7 @@ export class PolicyProxyHandler implements Pr // conduct the deletion if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`delete\`:\n${formatObject(args)}`); + this.logger.info(`[withPolicy] \`delete\` ${this.model}:\n${formatObject(args)}`); } await this.modelClient.delete(args); @@ -298,7 +298,7 @@ export class PolicyProxyHandler implements Pr // conduct the deletion if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`deleteMany\`:\n${formatObject(args)}`); + this.logger.info(`[withPolicy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`); } return this.modelClient.deleteMany(args); } @@ -314,7 +314,7 @@ export class PolicyProxyHandler implements Pr await this.utils.injectAuthGuard(args, this.model, 'read'); if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`aggregate\`:\n${formatObject(args)}`); + this.logger.info(`[withPolicy] \`aggregate\` ${this.model}:\n${formatObject(args)}`); } return this.modelClient.aggregate(args); } @@ -330,7 +330,7 @@ export class PolicyProxyHandler implements Pr await this.utils.injectAuthGuard(args, this.model, 'read'); if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`groupBy\`:\n${formatObject(args)}`); + this.logger.info(`[withPolicy] \`groupBy\` ${this.model}:\n${formatObject(args)}`); } return this.modelClient.groupBy(args); } @@ -343,7 +343,7 @@ export class PolicyProxyHandler implements Pr await this.utils.injectAuthGuard(args, this.model, 'read'); if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`count\`:\n${formatObject(args)}`); + this.logger.info(`[withPolicy] \`count\` ${this.model}:\n${formatObject(args)}`); } return this.modelClient.count(args); } diff --git a/packages/runtime/src/enhancements/policy/index.ts b/packages/runtime/src/enhancements/policy/index.ts index 1fbe249f1..5ed4eeea2 100644 --- a/packages/runtime/src/enhancements/policy/index.ts +++ b/packages/runtime/src/enhancements/policy/index.ts @@ -2,6 +2,8 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import path from 'path'; +import semver from 'semver'; +import { PRISMA_MINIMUM_VERSION } from '../../constants'; import { AuthUser, DbClientContract } from '../../types'; import { getDefaultModelMeta } from '../model-meta'; import { makeProxy } from '../proxy'; @@ -53,6 +55,17 @@ export function withPolicy( context?: WithPolicyContext, options?: WithPolicyOptions ): DbClient { + if (!prisma) { + throw new Error('Invalid prisma instance'); + } + + const prismaVer = (prisma as any)._clientVersion; + if (prismaVer && semver.lt(prismaVer, PRISMA_MINIMUM_VERSION)) { + console.warn( + `ZenStack requires Prisma version "${PRISMA_MINIMUM_VERSION}" or higher. Detected version is "${prismaVer}".` + ); + } + const _policy = options?.policy ?? getDefaultPolicy(); const _modelMeta = options?.modelMeta ?? getDefaultModelMeta(); const _zodSchemas = options?.zodSchemas ?? getDefaultZodSchemas(); diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 91d1cf78e..1f45a906f 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -14,7 +14,6 @@ import { PrismaErrorCode, TRANSACTION_FIELD_NAME, } from '../../constants'; -import { isPrismaClientKnownRequestError } from '../../error'; import { AuthUser, DbClientContract, @@ -23,7 +22,7 @@ import { PolicyOperationKind, PrismaWriteActionType, } from '../../types'; -import { getPrismaVersion, getVersion } from '../../version'; +import { getVersion } from '../../version'; import { getFields, resolveField } from '../model-meta'; import { NestedWriteVisitor, type NestedWriteVisitorContext } from '../nested-write-vistor'; import type { ModelMeta, PolicyDef, PolicyFunc, ZodSchemas } from '../types'; @@ -36,7 +35,6 @@ import { prismaClientUnknownRequestError, } from '../utils'; import { Logger } from './logger'; -import semver from 'semver'; /** * Access policy enforcement utilities @@ -46,8 +44,6 @@ export class PolicyUtil { // @ts-ignore private readonly logger: Logger; - private supportNestedToOneFilter = false; - constructor( private readonly db: DbClientContract, private readonly modelMeta: ModelMeta, @@ -57,10 +53,6 @@ export class PolicyUtil { private readonly logPrismaQuery?: boolean ) { this.logger = new Logger(db); - - // use Prisma version to detect if we can filter when nested-fetching to-one relation - const prismaVersion = getPrismaVersion(); - this.supportNestedToOneFilter = prismaVersion ? semver.gte(prismaVersion, '4.8.0') : false; } /** @@ -267,15 +259,21 @@ export class PolicyUtil { await this.injectAuthGuard(args, model, 'read'); - // recursively inject read guard conditions into the query args - await this.injectNestedReadConditions(model, args); + // recursively inject read guard conditions into nested select, include, and _count + const hoistedConditions = await this.injectNestedReadConditions(model, args); + + // the injection process may generate conditions that need to be hoisted to the toplevel, + // if so, merge it with the existing where + if (hoistedConditions && Object.keys(hoistedConditions).length > 0) { + args.where = this.and(args.where, ...hoistedConditions); + } if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`findMany\`:\n${formatObject(args)}`); + this.logger.info(`[withPolicy] \`findMany\` ${model}:\n${formatObject(args)}`); } const result: any[] = await this.db[model].findMany(args); - await this.postProcessForRead(result, model, args, 'read'); + this.postProcessForRead(result, args); return result; } @@ -284,7 +282,6 @@ export class PolicyUtil { async flattenGeneratedUniqueField(model: string, args: any) { // e.g.: { a_b: { a: '1', b: '1' } } => { a: '1', b: '1' } const uniqueConstraints = this.modelMeta.uniqueConstraints?.[lowerCaseFirst(model)]; - let flattened = false; if (uniqueConstraints && Object.keys(uniqueConstraints).length > 0) { for (const [field, value] of Object.entries(args)) { if (uniqueConstraints[field] && typeof value === 'object') { @@ -292,21 +289,15 @@ export class PolicyUtil { args[f] = v; } delete args[field]; - flattened = true; } } } - - if (flattened) { - // DEBUG - // this.logger.info(`Filter flattened: ${JSON.stringify(args)}`); - } } - private async injectNestedReadConditions(model: string, args: any) { + private async injectNestedReadConditions(model: string, args: any): Promise { const injectTarget = args.select ?? args.include; if (!injectTarget) { - return; + return []; } if (injectTarget._count !== undefined) { @@ -340,7 +331,8 @@ export class PolicyUtil { } } - const idFields = this.getIdFields(model); + // collect filter conditions that should be hoisted to the toplevel + const hoistedConditions: any[] = []; for (const field of getModelFields(injectTarget)) { const fieldInfo = resolveField(this.modelMeta, model, field); @@ -349,35 +341,41 @@ export class PolicyUtil { continue; } + let hoisted: any; + if ( fieldInfo.isArray || - // if Prisma version is high enough to support filtering directly when - // fetching a nullable to-one relation, let's do it that way + // Injecting where at include/select level for nullable to-one relation is supported since Prisma 4.8.0 // https://github.com/prisma/prisma/discussions/20350 - (this.supportNestedToOneFilter && fieldInfo.isOptional) + fieldInfo.isOptional ) { if (typeof injectTarget[field] !== 'object') { injectTarget[field] = {}; } // inject extra condition for to-many or nullable to-one relation await this.injectAuthGuard(injectTarget[field], fieldInfo.type, 'read'); - - // recurse - await this.injectNestedReadConditions(fieldInfo.type, injectTarget[field]); } else { - // there's no way of injecting condition for to-one relation, so if there's - // "select" clause we make sure 'id' fields are selected and check them against - // query result; nothing needs to be done for "include" clause because all - // fields are already selected - if (injectTarget[field]?.select) { - for (const idField of idFields) { - if (injectTarget[field].select[idField.name] !== true) { - injectTarget[field].select[idField.name] = true; - } - } + // hoist non-nullable to-one filter to the parent level + const guard = this.getAuthGuard(fieldInfo.type, 'read'); + if (guard !== true) { + // use "and" to resolve boolean values + hoisted = this.and(guard); } } + + // recurse + const subHoisted = await this.injectNestedReadConditions(fieldInfo.type, injectTarget[field]); + + if (subHoisted.length > 0) { + hoisted = this.and(hoisted, ...subHoisted); + } + + if (hoisted !== undefined) { + hoistedConditions.push({ [field]: hoisted }); + } } + + return hoistedConditions; } /** @@ -385,80 +383,31 @@ export class PolicyUtil { * (which can't be trimmed at query time) and removes fields that should be * omitted. */ - async postProcessForRead(data: any, model: string, args: any, operation: PolicyOperationKind) { - await Promise.all( - enumerate(data).map((entityData) => this.postProcessSingleEntityForRead(entityData, model, args, operation)) - ); - } - - private async postProcessSingleEntityForRead(data: any, model: string, args: any, operation: PolicyOperationKind) { - if (typeof data !== 'object' || !data) { - return; - } - - // strip auxiliary fields - for (const auxField of AUXILIARY_FIELDS) { - if (auxField in data) { - delete data[auxField]; + private postProcessForRead(data: any, args: any) { + for (const entityData of enumerate(data)) { + if (typeof entityData !== 'object' || !entityData) { + return; } - } - - const injectTarget = args.select ?? args.include; - if (!injectTarget) { - return; - } - // recurse into nested entities - for (const field of Object.keys(injectTarget)) { - const fieldData = data[field]; - if (typeof fieldData !== 'object' || !fieldData) { - continue; + // strip auxiliary fields + for (const auxField of AUXILIARY_FIELDS) { + if (auxField in entityData) { + delete entityData[auxField]; + } } - const fieldInfo = resolveField(this.modelMeta, model, field); - if (fieldInfo) { - if ( - fieldInfo.isDataModel && - !fieldInfo.isArray && - // if Prisma version supports filtering nullable to-one relation, no need to further check - !(this.supportNestedToOneFilter && fieldInfo.isOptional) - ) { - // to-one relation data cannot be trimmed by injected guards, we have to - // post-check them - const ids = this.getEntityIds(fieldInfo.type, fieldData); - - if (Object.keys(ids).length !== 0) { - if (this.logger.enabled('info')) { - this.logger.info( - `Validating read of to-one relation: ${fieldInfo.type}#${formatObject(ids)}` - ); - } + const injectTarget = args.select ?? args.include; + if (!injectTarget) { + return; + } - try { - await this.checkPolicyForFilter(fieldInfo.type, ids, operation, this.db); - } catch (err) { - if ( - isPrismaClientKnownRequestError(err) && - err.code === PrismaErrorCode.CONSTRAINED_FAILED - ) { - // denied by policy - if (fieldInfo.isOptional) { - // if the relation is optional, just nullify it - data[field] = null; - } else { - // otherwise reject - throw err; - } - } else { - // unknown error - throw err; - } - } - } + // recurse into nested entities + for (const field of Object.keys(injectTarget)) { + const fieldData = entityData[field]; + if (typeof fieldData !== 'object' || !fieldData) { + continue; } - - // recurse - await this.postProcessForRead(fieldData, fieldInfo.type, injectTarget[field], operation); + this.postProcessForRead(fieldData, injectTarget[field]); } } } @@ -655,7 +604,7 @@ export class PolicyUtil { const query = { where: filter, select }; if (this.shouldLogQuery) { this.logger.info( - `[withPolicy] \`findMany\` for fetching pre-update entities:\n${formatObject(args)}` + `[withPolicy] \`findMany\` ${model} for fetching pre-update entities:\n${formatObject(args)}` ); } const entities = await this.db[model].findMany(query); diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 4bd5b896c..b6a6f80bc 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -20,6 +20,7 @@ import { LiteralExpr, Model, } from '@zenstackhq/language/ast'; +import { PRISMA_MINIMUM_VERSION } from '@zenstackhq/runtime'; import { analyzePolicies, getDataModels, @@ -80,9 +81,17 @@ export default class PrismaSchemaGenerator { `; async generate(model: Model, options: PluginOptions, config?: Record) { - const prisma = new PrismaModel(); const warnings: string[] = []; + const prismaVersion = getPrismaVersion(); + if (prismaVersion && semver.lt(prismaVersion, PRISMA_MINIMUM_VERSION)) { + warnings.push( + `ZenStack requires Prisma version "${PRISMA_MINIMUM_VERSION}" or higher. Detected version is "${prismaVersion}".` + ); + } + + const prisma = new PrismaModel(); + for (const decl of model.declarations) { switch (decl.$type) { case DataSource: @@ -252,26 +261,23 @@ export default class PrismaSchemaGenerator { if (provider?.value === 'prisma-client-js') { const prismaVersion = getPrismaVersion(); if (prismaVersion) { - let previewFeatures = generator.fields.find((f) => f.name === 'previewFeatures'); - if (!previewFeatures) { - previewFeatures = { name: 'previewFeatures', value: [] }; - generator.fields.push(previewFeatures); - } - if (!Array.isArray(previewFeatures.value)) { + const previewFeatures = generator.fields.find((f) => f.name === 'previewFeatures')?.value ?? []; + + if (!Array.isArray(previewFeatures)) { throw new PluginError(name, 'option "previewFeatures" must be an array'); } - if (semver.lt(prismaVersion, '4.7.0')) { - // interactiveTransactions feature is opt-in before 4.7.0 - if (!previewFeatures.value.includes('interactiveTransactions')) { - previewFeatures.value.push('interactiveTransactions'); + if (semver.lt(prismaVersion, '5.0.0')) { + // extendedWhereUnique feature is opt-in pre V5 + if (!previewFeatures.includes('extendedWhereUnique')) { + previewFeatures.push('extendedWhereUnique'); } } - if (semver.gte(prismaVersion, '4.8.0') && semver.lt(prismaVersion, '5.0.0')) { - // extendedWhereUnique feature is opt-in during [4.8.0, 5.0.0) - if (!previewFeatures.value.includes('extendedWhereUnique')) { - previewFeatures.value.push('extendedWhereUnique'); + if (previewFeatures.length > 0) { + const curr = generator.fields.find((f) => f.name === 'previewFeatures'); + if (!curr) { + generator.fields.push({ name: 'previewFeatures', value: previewFeatures }); } } } diff --git a/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts b/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts index 8bd5bd6cc..2d7326ed8 100644 --- a/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts +++ b/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts @@ -10,10 +10,12 @@ describe('With Policy:deep nested', () => { model M1 { myId String @id @default(cuid()) m2 M2? + value Int @default(0) @@allow('all', true) @@deny('create', m2.m4?[value == 100]) @@deny('update', m2.m4?[value == 101]) + @@deny('read', value == 100) } model M2 { @@ -59,20 +61,92 @@ describe('With Policy:deep nested', () => { `; let db: WeakDbClientContract; + let prisma: WeakDbClientContract; beforeAll(async () => { origDir = path.resolve('.'); }); beforeEach(async () => { - const { withPolicy } = await loadSchema(model); - db = withPolicy(); + const params = await loadSchema(model, { logPrismaQuery: true }); + db = params.withPolicy(); + prisma = params.prisma; }); afterEach(() => { process.chdir(origDir); }); + it('read', async () => { + await prisma.m1.create({ + data: { + myId: '1', + m2: { + create: { + value: 1, + m3: { + create: { id: '3-1', value: 31 }, + }, + m4: { + create: [{ value: 41 }, { value: 42 }], + }, + }, + }, + }, + }); + // all readable + let r = await db.m1.findUnique({ + where: { myId: '1' }, + include: { m2: { include: { m3: true, m4: true } } }, + }); + expect(r.m2.m3).toBeTruthy(); + expect(r.m2.m4).toHaveLength(2); + r = await db.m3.findUnique({ where: { id: '3-1' }, include: { m2: { include: { m1: true } } } }); + expect(r.m2.m1).toBeTruthy(); + + await prisma.m1.create({ + data: { + myId: '2', + m2: { + create: { + value: 1, + m3: { + create: { value: 200 }, + }, + m4: { + create: [{ value: 22 }, { value: 200 }], + }, + }, + }, + }, + }); + // check filtered + r = await db.m1.findUnique({ + where: { myId: '2' }, + include: { m2: { include: { m3: true, m4: true } } }, + }); + expect(r.m2.m3).toBeNull(); + expect(r.m2.m4).toHaveLength(1); + + await prisma.m1.create({ + data: { + myId: '3', + value: 100, + m2: { + create: { + value: 1, + m3: { + create: { id: '3-2', value: 31 }, + }, + }, + }, + }, + }); + // check hoisted filtering, due to m1 is not readable + r = await db.m3.findUnique({ where: { id: '3-2' }, include: { m2: { include: { m1: true } } } }); + expect(r).toBeNull(); + }); + it('create', async () => { await expect( db.m1.create({ diff --git a/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts b/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts index 45a0e765d..59718a84c 100644 --- a/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts +++ b/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts @@ -92,8 +92,8 @@ describe('With Policy:nested to-one', () => { }); const db = withPolicy(); - await expect(db.m2.findUnique({ where: { id: '1' }, include: { m1: true } })).toBeRejectedByPolicy(); - await expect(db.m2.findMany({ include: { m1: true } })).toBeRejectedByPolicy(); + await expect(db.m2.findUnique({ where: { id: '1' }, include: { m1: true } })).toResolveFalsy(); + await expect(db.m2.findMany({ include: { m1: true } })).resolves.toHaveLength(0); await prisma.m1.update({ where: { id: '1' }, data: { value: 1 } }); await expect(db.m2.findMany({ include: { m1: true } })).toResolveTruthy(); diff --git a/tests/integration/tests/enhancements/with-policy/view.test.ts b/tests/integration/tests/enhancements/with-policy/view.test.ts index 1af13a65c..f5abe6439 100644 --- a/tests/integration/tests/enhancements/with-policy/view.test.ts +++ b/tests/integration/tests/enhancements/with-policy/view.test.ts @@ -100,6 +100,6 @@ describe('View Policy Test', () => { expect(r1.user).toBeTruthy(); // user not readable - await expect(db.userInfo.findFirst({ include: { user: true } })).toBeRejectedByPolicy(); + await expect(db.userInfo.findFirst({ include: { user: true } })).toResolveFalsy(); }); }); From 7757b5772c410d0367215f75e5b25062e0033a87 Mon Sep 17 00:00:00 2001 From: Yiming Date: Mon, 31 Jul 2023 11:35:02 +0800 Subject: [PATCH 2/5] merge main back to dev (#611) Co-authored-by: Jiasheng --- package.json | 2 +- packages/language/package.json | 2 +- packages/plugins/openapi/package.json | 2 +- packages/plugins/swr/package.json | 2 +- packages/plugins/tanstack-query/package.json | 2 +- packages/plugins/trpc/package.json | 2 +- packages/runtime/package.json | 2 +- packages/schema/package.json | 2 +- .../schema/src/plugins/model-meta/index.ts | 24 ++++++--- packages/sdk/package.json | 2 +- packages/server/package.json | 2 +- packages/testtools/package.json | 2 +- .../tests/regression/issues.test.ts | 54 +++++++++++++++++++ 13 files changed, 83 insertions(+), 17 deletions(-) diff --git a/package.json b/package.json index 7a161fe47..556149eac 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "zenstack-monorepo", - "version": "1.0.0-beta.12", + "version": "1.0.0-beta.13", "description": "", "scripts": { "build": "pnpm -r build", diff --git a/packages/language/package.json b/packages/language/package.json index d8b521145..5c2e8bc97 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/language", - "version": "1.0.0-beta.12", + "version": "1.0.0-beta.13", "displayName": "ZenStack modeling language compiler", "description": "ZenStack modeling language compiler", "homepage": "https://zenstack.dev", diff --git a/packages/plugins/openapi/package.json b/packages/plugins/openapi/package.json index 1bf11bc76..175cd84b7 100644 --- a/packages/plugins/openapi/package.json +++ b/packages/plugins/openapi/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/openapi", "displayName": "ZenStack Plugin and Runtime for OpenAPI", - "version": "1.0.0-beta.12", + "version": "1.0.0-beta.13", "description": "ZenStack plugin and runtime supporting OpenAPI", "main": "index.js", "repository": { diff --git a/packages/plugins/swr/package.json b/packages/plugins/swr/package.json index 642252a55..083912f6c 100644 --- a/packages/plugins/swr/package.json +++ b/packages/plugins/swr/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/swr", "displayName": "ZenStack plugin for generating SWR hooks", - "version": "1.0.0-beta.12", + "version": "1.0.0-beta.13", "description": "ZenStack plugin for generating SWR hooks", "main": "index.js", "repository": { diff --git a/packages/plugins/tanstack-query/package.json b/packages/plugins/tanstack-query/package.json index fa7589bf9..130c4b051 100644 --- a/packages/plugins/tanstack-query/package.json +++ b/packages/plugins/tanstack-query/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/tanstack-query", "displayName": "ZenStack plugin for generating tanstack-query hooks", - "version": "1.0.0-beta.12", + "version": "1.0.0-beta.13", "description": "ZenStack plugin for generating tanstack-query hooks", "main": "index.js", "exports": { diff --git a/packages/plugins/trpc/package.json b/packages/plugins/trpc/package.json index 303b5e81f..3f5723880 100644 --- a/packages/plugins/trpc/package.json +++ b/packages/plugins/trpc/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/trpc", "displayName": "ZenStack plugin for tRPC", - "version": "1.0.0-beta.12", + "version": "1.0.0-beta.13", "description": "ZenStack plugin for tRPC", "main": "index.js", "repository": { diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 55e3f8ca7..f310a999d 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/runtime", "displayName": "ZenStack Runtime Library", - "version": "1.0.0-beta.12", + "version": "1.0.0-beta.13", "description": "Runtime of ZenStack for both client-side and server-side environments.", "repository": { "type": "git", diff --git a/packages/schema/package.json b/packages/schema/package.json index 6121048bd..8de326b9a 100644 --- a/packages/schema/package.json +++ b/packages/schema/package.json @@ -3,7 +3,7 @@ "publisher": "zenstack", "displayName": "ZenStack Language Tools", "description": "A toolkit for building secure CRUD apps with Next.js + Typescript", - "version": "1.0.0-beta.12", + "version": "1.0.0-beta.13", "author": { "name": "ZenStack Team" }, diff --git a/packages/schema/src/plugins/model-meta/index.ts b/packages/schema/src/plugins/model-meta/index.ts index 892f57c36..ce6d82785 100644 --- a/packages/schema/src/plugins/model-meta/index.ts +++ b/packages/schema/src/plugins/model-meta/index.ts @@ -11,6 +11,7 @@ import type { RuntimeAttribute } from '@zenstackhq/runtime'; import { createProject, emitProject, + getAttributeArg, getAttributeArgs, getDataModels, getLiteral, @@ -182,14 +183,25 @@ function isRelationOwner(field: DataModelField, backLink: DataModelField | undef return false; } - if (hasAttribute(field, '@relation')) { - // this field has `@relation` attribute + if (!backLink) { + // CHECKME: can this really happen? return true; - } else if (!backLink || !hasAttribute(backLink, '@relation')) { - // if the opposite side field doesn't have `@relation` attribute either, - // it's an implicit many-to-many relation, both sides are owners + } + + if (!hasAttribute(field, '@relation') && !hasAttribute(backLink, '@relation')) { + // if neither side has `@relation` attribute, it's an implicit many-to-many relation, + // both sides are owners return true; - } else { + } + + return holdsForeignKey(field); +} + +function holdsForeignKey(field: DataModelField) { + const relation = field.attributes.find((attr) => attr.decl.ref?.name === '@relation'); + if (!relation) { return false; } + const fields = getAttributeArg(relation, 'fields'); + return !!fields; } diff --git a/packages/sdk/package.json b/packages/sdk/package.json index e429e8107..d9e744361 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/sdk", - "version": "1.0.0-beta.12", + "version": "1.0.0-beta.13", "description": "ZenStack plugin development SDK", "main": "index.js", "scripts": { diff --git a/packages/server/package.json b/packages/server/package.json index a39d0f358..030f70d3d 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/server", - "version": "1.0.0-beta.12", + "version": "1.0.0-beta.13", "displayName": "ZenStack Server-side Adapters", "description": "ZenStack server-side adapters", "homepage": "https://zenstack.dev", diff --git a/packages/testtools/package.json b/packages/testtools/package.json index 63513bd2a..1836a12ef 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/testtools", - "version": "1.0.0-beta.12", + "version": "1.0.0-beta.13", "description": "ZenStack Test Tools", "main": "index.js", "publishConfig": { diff --git a/tests/integration/tests/regression/issues.test.ts b/tests/integration/tests/regression/issues.test.ts index 88551a62d..afb709667 100644 --- a/tests/integration/tests/regression/issues.test.ts +++ b/tests/integration/tests/regression/issues.test.ts @@ -300,4 +300,58 @@ describe('GitHub issues regression', () => { }, }); }); + + it('issue 609', async () => { + const { withPolicy, prisma } = await loadSchema( + ` + model User { + id String @id @default(cuid()) + comments Comment[] + } + + model Comment { + id String @id @default(cuid()) + parentCommentId String? + replies Comment[] @relation("CommentToComment") + parent Comment? @relation("CommentToComment", fields: [parentCommentId], references: [id]) + comment String + author User @relation(fields: [authorId], references: [id]) + authorId String + + @@allow('read,create', true) + @@allow('update,delete', auth() == author) + } + ` + ); + + await prisma.user.create({ + data: { + id: '1', + comments: { + create: { + id: '1', + comment: 'Comment 1', + }, + }, + }, + }); + + await prisma.user.create({ + data: { + id: '2', + }, + }); + + // connecting a child comment from a different user to a parent comment should succeed + const db = withPolicy({ id: '2' }); + await expect( + db.comment.create({ + data: { + comment: 'Comment 2', + author: { connect: { id: '2' } }, + parent: { connect: { id: '1' } }, + }, + }) + ).toResolveTruthy(); + }); }); From f456a97078e3e42945c2a3397b44f6b2c307354c Mon Sep 17 00:00:00 2001 From: Yiming Date: Fri, 4 Aug 2023 19:30:00 +0800 Subject: [PATCH 3/5] refactor to policy check Don't rely on aux fields for policy check anymore --- .github/workflows/build-test.yml | 14 + package.json | 4 +- packages/runtime/package.json | 1 - packages/runtime/src/constants.ts | 20 +- .../src/enhancements/model-data-visitor.ts | 43 + .../runtime/src/enhancements/model-meta.ts | 3 +- .../src/enhancements/nested-write-vistor.ts | 140 +- packages/runtime/src/enhancements/omit.ts | 2 +- .../src/enhancements/policy/handler.ts | 982 +++++++++++--- .../src/enhancements/policy/policy-utils.ts | 898 +++++-------- packages/runtime/src/enhancements/proxy.ts | 12 +- packages/runtime/src/enhancements/types.ts | 8 + packages/runtime/src/enhancements/utils.ts | 4 +- packages/runtime/src/error.ts | 13 +- packages/runtime/src/types.ts | 8 +- .../access-policy/policy-guard-generator.ts | 181 ++- .../schema/src/plugins/model-meta/index.ts | 39 + .../typescript-expression-transformer.ts | 2 +- packages/server/src/api/rpc/index.ts | 45 +- packages/testtools/src/schema.ts | 2 +- pnpm-lock.yaml | 159 ++- tests/integration/package.json | 2 + .../with-policy/deep-nested.test.ts | 15 +- .../with-policy/empty-policy.test.ts | 20 +- .../with-policy/nested-to-many.test.ts | 2 +- .../with-policy/nested-to-one.test.ts | 8 +- .../with-policy/post-update.test.ts | 5 +- .../enhancements/with-policy/postgres.test.ts | 526 ++++++++ .../enhancements/with-policy/refactor.test.ts | 1143 +++++++++++++++++ .../with-policy/toplevel-operations.test.ts | 3 +- .../tests/schema/refactor-pg.zmodel | 100 ++ tests/integration/tests/schema/todo-pg.zmodel | 152 +++ tests/integration/utils/jest-ext.ts | 20 +- 33 files changed, 3717 insertions(+), 859 deletions(-) create mode 100644 packages/runtime/src/enhancements/model-data-visitor.ts create mode 100644 tests/integration/tests/enhancements/with-policy/postgres.test.ts create mode 100644 tests/integration/tests/enhancements/with-policy/refactor.test.ts create mode 100644 tests/integration/tests/schema/refactor-pg.zmodel create mode 100644 tests/integration/tests/schema/todo-pg.zmodel diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 941a3684f..4857fda3f 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -15,6 +15,20 @@ jobs: build-test: runs-on: buildjet-8vcpu-ubuntu-2204 + services: + postgres: + image: postgres + env: + POSTGRES_PASSWORD: abc123 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 5432:5432 + strategy: matrix: node-version: [18.x] diff --git a/package.json b/package.json index 556149eac..4c35012c0 100644 --- a/package.json +++ b/package.json @@ -4,8 +4,8 @@ "description": "", "scripts": { "build": "pnpm -r build", - "test": "ZENSTACK_TEST=1 pnpm -r run test --silent", - "test-ci": "ZENSTACK_TEST=1 pnpm -r run test --silent", + "test": "ZENSTACK_TEST=1 pnpm -r run test --silent --forceExit", + "test-ci": "ZENSTACK_TEST=1 pnpm -r run test --silent --forceExit", "lint": "pnpm -r lint", "publish-all": "pnpm --filter \"./packages/**\" -r publish --access public", "publish-preview": "pnpm --filter \"./packages/**\" -r publish --force --registry http://localhost:4873" diff --git a/packages/runtime/package.json b/packages/runtime/package.json index f310a999d..7cf48b668 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -45,7 +45,6 @@ "linkDirectory": true }, "dependencies": { - "@paralleldrive/cuid2": "^2.2.0", "@types/bcryptjs": "^2.4.2", "bcryptjs": "^2.4.3", "buffer": "^6.0.3", diff --git a/packages/runtime/src/constants.ts b/packages/runtime/src/constants.ts index 5297bcb59..fb3644e60 100644 --- a/packages/runtime/src/constants.ts +++ b/packages/runtime/src/constants.ts @@ -37,13 +37,31 @@ export enum CrudFailureReason { * Prisma error codes used */ export enum PrismaErrorCode { + /** + * Unique constraint failed + */ + UNIQUE_CONSTRAINT_FAILED = 'P2002', + + /** + * A constraint failed on the database + */ CONSTRAINED_FAILED = 'P2004', + + /** + * The required connected records were not found + */ + REQUIRED_CONNECTED_RECORD_NOT_FOUND = 'P2018', + + /** + * An operation failed because it depends on one or more records that were required but not found + */ + DEPEND_ON_RECORD_NOT_FOUND = 'P2025', } /** * Field name for storing in-transaction flag */ -export const PRISIMA_TX_FLAG = '$__zenstack_tx'; +export const PRISMA_TX_FLAG = '$__zenstack_tx'; /** * Field name for getting current enhancer diff --git a/packages/runtime/src/enhancements/model-data-visitor.ts b/packages/runtime/src/enhancements/model-data-visitor.ts new file mode 100644 index 000000000..50a6dfbe3 --- /dev/null +++ b/packages/runtime/src/enhancements/model-data-visitor.ts @@ -0,0 +1,43 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ +import { resolveField } from './model-meta'; +import { ModelMeta } from './types'; + +/** + * Callback for @see ModelDataVisitor. + */ +export type ModelDataVisitorCallback = (model: string, data: any, scalarData: any) => void; + +/** + * Visitor that traverses data returned by a Prisma query. + */ +export class ModelDataVisitor { + constructor(private modelMeta: ModelMeta) {} + + /** + * Visits the given model data. + */ + visit(model: string, data: any, callback: ModelDataVisitorCallback) { + if (!data || typeof data !== 'object') { + return; + } + + const scalarData: Record = {}; + const subTasks: Array<{ model: string; data: any }> = []; + + for (const [k, v] of Object.entries(data)) { + const field = resolveField(this.modelMeta, model, k); + if (field && field.isDataModel) { + if (field.isArray && Array.isArray(v)) { + subTasks.push(...v.map((item) => ({ model: field.type, data: item }))); + } else { + subTasks.push({ model: field.type, data: v }); + } + } else { + scalarData[k] = v; + } + } + + callback(model, data, scalarData); + subTasks.forEach(({ model, data }) => this.visit(model, data, callback)); + } +} diff --git a/packages/runtime/src/enhancements/model-meta.ts b/packages/runtime/src/enhancements/model-meta.ts index ee480db5c..83eef9a64 100644 --- a/packages/runtime/src/enhancements/model-meta.ts +++ b/packages/runtime/src/enhancements/model-meta.ts @@ -1,6 +1,7 @@ /* eslint-disable @typescript-eslint/no-var-requires */ import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; +import { FieldInfo } from '../types'; import { ModelMeta } from './types'; /** @@ -26,7 +27,7 @@ export function getDefaultModelMeta(): ModelMeta { /** * Resolves a model field to its metadata. Returns undefined if not found. */ -export function resolveField(modelMeta: ModelMeta, model: string, field: string) { +export function resolveField(modelMeta: ModelMeta, model: string, field: string): FieldInfo | undefined { return modelMeta.fields[lowerCaseFirst(model)][field]; } diff --git a/packages/runtime/src/enhancements/nested-write-vistor.ts b/packages/runtime/src/enhancements/nested-write-vistor.ts index 5b696bbd0..226773242 100644 --- a/packages/runtime/src/enhancements/nested-write-vistor.ts +++ b/packages/runtime/src/enhancements/nested-write-vistor.ts @@ -29,38 +29,56 @@ export type NestedWriteVisitorContext = { }; /** - * NestedWriteVisitor's callback actions + * NestedWriteVisitor's callback actions. A call back function should return true or void to indicate + * that the visitor should continue traversing its children, or false to stop. It can also return an object + * to let the visitor traverse it instead of its original children. */ export type NestedWriterVisitorCallback = { - create?: (model: string, args: any[], context: NestedWriteVisitorContext) => Promise; + create?: (model: string, args: any[], context: NestedWriteVisitorContext) => Promise; + + createMany?: ( + model: string, + args: { data: any; skipDuplicates?: boolean }, + context: NestedWriteVisitorContext + ) => Promise; connectOrCreate?: ( model: string, args: { where: object; create: any }, context: NestedWriteVisitorContext - ) => Promise; + ) => Promise; + + connect?: (model: string, args: object, context: NestedWriteVisitorContext) => Promise; - connect?: (model: string, args: object, context: NestedWriteVisitorContext) => Promise; + disconnect?: (model: string, args: object, context: NestedWriteVisitorContext) => Promise; - disconnect?: (model: string, args: object, context: NestedWriteVisitorContext) => Promise; + set?: (model: string, args: object, context: NestedWriteVisitorContext) => Promise; - update?: (model: string, args: { where: object; data: any }, context: NestedWriteVisitorContext) => Promise; + update?: (model: string, args: object, context: NestedWriteVisitorContext) => Promise; updateMany?: ( model: string, args: { where?: object; data: any }, context: NestedWriteVisitorContext - ) => Promise; + ) => Promise; upsert?: ( model: string, args: { where: object; create: any; update: any }, context: NestedWriteVisitorContext - ) => Promise; + ) => Promise; - delete?: (model: string, args: object | boolean, context: NestedWriteVisitorContext) => Promise; + delete?: ( + model: string, + args: object | boolean, + context: NestedWriteVisitorContext + ) => Promise; - deleteMany?: (model: string, args: any | object, context: NestedWriteVisitorContext) => Promise; + deleteMany?: ( + model: string, + args: any | object, + context: NestedWriteVisitorContext + ) => Promise; field?: ( field: FieldInfo, @@ -71,7 +89,7 @@ export type NestedWriterVisitorCallback = { }; /** - * Recursive visitor for nested write (create/update) payload + * Recursive visitor for nested write (create/update) payload. */ export class NestedWriteVisitor { constructor(private readonly modelMeta: ModelMeta, private readonly callback: NestedWriterVisitorCallback) {} @@ -91,7 +109,6 @@ export class NestedWriteVisitor { } let topData = args; - // const topWhere = { ...topData.where }; switch (action) { // create has its data wrapped in 'data' field @@ -120,41 +137,50 @@ export class NestedWriteVisitor { return; } - const isToOneUpdate = field?.isDataModel && !field.isArray; const context = { parent, field, nestingPath: [...nestingPath] }; + const toplevel = field == undefined; // visit payload switch (action) { case 'create': context.nestingPath.push({ field, model, where: {}, unique: false }); for (const item of enumerate(data)) { + let callbackResult: any; if (this.callback.create) { - await this.callback.create(model, item, context); + callbackResult = await this.callback.create(model, item, context); + } + if (callbackResult !== false) { + const subPayload = typeof callbackResult === 'object' ? callbackResult : item; + await this.visitSubPayload(model, action, subPayload, context.nestingPath); } - await this.visitSubPayload(model, action, item, context.nestingPath); } break; case 'createMany': - // skip the 'data' layer so as to keep consistency with 'create' - if (data.data) { + if (data) { context.nestingPath.push({ field, model, where: {}, unique: false }); - for (const item of enumerate(data.data)) { - if (this.callback.create) { - await this.callback.create(model, item, context); - } - await this.visitSubPayload(model, action, item, context.nestingPath); + let callbackResult: any; + if (this.callback.createMany) { + callbackResult = await this.callback.createMany(model, data, context); + } + if (callbackResult !== false) { + const subPayload = typeof callbackResult === 'object' ? callbackResult : data.data; + await this.visitSubPayload(model, action, subPayload, context.nestingPath); } } break; case 'connectOrCreate': - context.nestingPath.push({ field, model, where: data.where, unique: true }); + context.nestingPath.push({ field, model, where: data.where, unique: false }); for (const item of enumerate(data)) { + let callbackResult: any; if (this.callback.connectOrCreate) { - await this.callback.connectOrCreate(model, item, context); + callbackResult = await this.callback.connectOrCreate(model, item, context); + } + if (callbackResult !== false) { + const subPayload = typeof callbackResult === 'object' ? callbackResult : item.create; + await this.visitSubPayload(model, action, subPayload, context.nestingPath); } - await this.visitSubPayload(model, action, item.create, context.nestingPath); } break; @@ -188,44 +214,76 @@ export class NestedWriteVisitor { } break; + case 'set': + if (this.callback.set) { + context.nestingPath.push({ field, model, where: {}, unique: false }); + await this.callback.set(model, data, context); + } + break; + case 'update': context.nestingPath.push({ field, model, where: data.where, unique: false }); for (const item of enumerate(data)) { + let callbackResult: any; if (this.callback.update) { - await this.callback.update(model, item, context); + callbackResult = await this.callback.update(model, item, context); + } + if (callbackResult !== false) { + const subPayload = + typeof callbackResult === 'object' + ? callbackResult + : typeof item.data === 'object' + ? item.data + : item; + await this.visitSubPayload(model, action, subPayload, context.nestingPath); } - const payload = isToOneUpdate ? item : item.data; - await this.visitSubPayload(model, action, payload, context.nestingPath); } break; case 'updateMany': context.nestingPath.push({ field, model, where: data.where, unique: false }); for (const item of enumerate(data)) { + let callbackResult: any; if (this.callback.updateMany) { - await this.callback.updateMany(model, item, context); + callbackResult = await this.callback.updateMany(model, item, context); + } + if (callbackResult !== false) { + const subPayload = typeof callbackResult === 'object' ? callbackResult : item; + await this.visitSubPayload(model, action, subPayload, context.nestingPath); } - await this.visitSubPayload(model, action, item, context.nestingPath); } break; case 'upsert': { - context.nestingPath.push({ field, model, where: data.where, unique: true }); + context.nestingPath.push({ field, model, where: data.where, unique: false }); for (const item of enumerate(data)) { + let callbackResult: any; if (this.callback.upsert) { - await this.callback.upsert(model, item, context); + callbackResult = await this.callback.upsert(model, item, context); + } + if (callbackResult !== false) { + if (typeof callbackResult === 'object') { + await this.visitSubPayload(model, action, callbackResult, context.nestingPath); + } else { + await this.visitSubPayload(model, action, item.create, context.nestingPath); + await this.visitSubPayload(model, action, item.update, context.nestingPath); + } } - await this.visitSubPayload(model, action, item.create, context.nestingPath); - await this.visitSubPayload(model, action, item.update, context.nestingPath); } break; } case 'delete': { if (this.callback.delete) { - context.nestingPath.push({ field, model, where: data.where, unique: false }); for (const item of enumerate(data)) { - await this.callback.delete(model, item, context); + const newContext = { + ...context, + nestingPath: [ + ...context.nestingPath, + { field, model, where: toplevel ? item.where : item, unique: false }, + ], + }; + await this.callback.delete(model, item, newContext); } } break; @@ -233,9 +291,15 @@ export class NestedWriteVisitor { case 'deleteMany': if (this.callback.deleteMany) { - context.nestingPath.push({ field, model, where: data.where, unique: false }); for (const item of enumerate(data)) { - await this.callback.deleteMany(model, item, context); + const newContext = { + ...context, + nestingPath: [ + ...context.nestingPath, + { field, model, where: toplevel ? item.where : item, unique: false }, + ], + }; + await this.callback.deleteMany(model, item, newContext); } } break; diff --git a/packages/runtime/src/enhancements/omit.ts b/packages/runtime/src/enhancements/omit.ts index 2b3f455c9..a23f1e7d3 100644 --- a/packages/runtime/src/enhancements/omit.ts +++ b/packages/runtime/src/enhancements/omit.ts @@ -12,7 +12,7 @@ import { enumerate, getModelFields } from './utils'; */ export type WithOmitOptions = { /** - * Model metatadata + * Model metadata */ modelMeta?: ModelMeta; }; diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index c22dd01c9..c9cdb98a7 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -1,13 +1,26 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { CrudFailureReason } from '../../constants'; -import { AuthUser, DbClientContract, PolicyOperationKind } from '../../types'; -import { BatchResult, PrismaProxyHandler } from '../proxy'; +import { upperCaseFirst } from 'upper-case-first'; +import { fromZodError } from 'zod-validation-error'; +import { CrudFailureReason, PRISMA_TX_FLAG } from '../../constants'; +import { AuthUser, DbClientContract, DbOperations, FieldInfo, PolicyOperationKind } from '../../types'; +import { ModelDataVisitor } from '../model-data-visitor'; +import { resolveField } from '../model-meta'; +import { NestedWriteVisitor, NestedWriteVisitorContext } from '../nested-write-vistor'; +import { PrismaProxyHandler } from '../proxy'; import type { ModelMeta, PolicyDef, ZodSchemas } from '../types'; -import { formatObject, prismaClientValidationError } from '../utils'; +import { enumerate, formatObject, getIdFields, prismaClientValidationError } from '../utils'; import { Logger } from './logger'; import { PolicyUtil } from './policy-utils'; +// a record for post-write policy check +type PostWriteCheckRecord = { + model: string; + operation: PolicyOperationKind; + uniqueFilter: any; + preValue?: any; +}; + /** * Prisma proxy handler for injecting access policy check. */ @@ -31,7 +44,7 @@ export class PolicyProxyHandler implements Pr this.policy, this.zodSchemas, this.user, - this.logPrismaQuery + this.shouldLogQuery ); } @@ -39,6 +52,10 @@ export class PolicyProxyHandler implements Pr return this.prisma[this.model]; } + //#region Find + + // find operations behaves as if the entities that don't match access policies don't exist + async findUnique(args: any) { if (!args) { throw prismaClientValidationError(this.prisma, 'query argument is required'); @@ -47,60 +64,86 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'where field is required in query argument'); } - const guard = this.utils.getAuthGuard(this.model, 'read'); - if (guard === false) { + args = this.utils.clone(args); + if (!(await this.utils.injectForRead(this.model, args))) { return null; } - const entities = await this.utils.readWithCheck(this.model, args); - return entities[0] ?? null; + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`findUnique\` ${this.model}:\n${formatObject(args)}`); + } + const result = await this.modelClient.findUnique(args); + this.utils.postProcessForRead(result); + return result; } async findUniqueOrThrow(args: any) { - const guard = this.utils.getAuthGuard(this.model, 'read'); - if (guard === false) { - throw this.utils.notFound(this.model); + if (!args) { + throw prismaClientValidationError(this.prisma, 'query argument is required'); + } + if (!args.where) { + throw prismaClientValidationError(this.prisma, 'where field is required in query argument'); } - const entity = await this.findUnique(args); - if (!entity) { + args = this.utils.clone(args); + if (!(await this.utils.injectForRead(this.model, args))) { throw this.utils.notFound(this.model); } - return entity; + + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`findUniqueOrThrow\` ${this.model}:\n${formatObject(args)}`); + } + const result = await this.modelClient.findUniqueOrThrow(args); + this.utils.postProcessForRead(result); + return result; } async findFirst(args: any) { - const guard = this.utils.getAuthGuard(this.model, 'read'); - if (guard === false) { + args = args ? this.utils.clone(args) : {}; + if (!(await this.utils.injectForRead(this.model, args))) { return null; } - const entities = await this.utils.readWithCheck(this.model, args); - return entities[0] ?? null; + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`findFirst\` ${this.model}:\n${formatObject(args)}`); + } + const result = await this.modelClient.findFirst(args); + this.utils.postProcessForRead(result); + return result; } async findFirstOrThrow(args: any) { - const guard = this.utils.getAuthGuard(this.model, 'read'); - if (guard === false) { + args = args ? this.utils.clone(args) : {}; + if (!(await this.utils.injectForRead(this.model, args))) { throw this.utils.notFound(this.model); } - const entity = await this.findFirst(args); - if (!entity) { - throw this.utils.notFound(this.model); + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`findFirstOrThrow\` ${this.model}:\n${formatObject(args)}`); } - return entity; + const result = await this.modelClient.findFirstOrThrow(args); + this.utils.postProcessForRead(result); + return result; } async findMany(args: any) { - const guard = this.utils.getAuthGuard(this.model, 'read'); - if (guard === false) { + args = args ? this.utils.clone(args) : {}; + if (!(await this.utils.injectForRead(this.model, args))) { return []; } - return this.utils.readWithCheck(this.model, args); + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`findMany\` ${this.model}:\n${formatObject(args)}`); + } + const result = await this.modelClient.findMany(args); + this.utils.postProcessForRead(result); + return result; } + //#endregion + + //#region Create + async create(args: any) { if (!args) { throw prismaClientValidationError(this.prisma, 'query argument is required'); @@ -109,52 +152,352 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'data field is required in query argument'); } - await this.tryReject('create'); + await this.utils.tryReject(this.model, 'create'); const origArgs = args; args = this.utils.clone(args); - // use a transaction to wrap the write so it can be reverted if the created - // entity fails access policies - const result: any = await this.utils.processWrite(this.model, 'create', args, (dbOps, writeArgs) => { - if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`create\` ${this.model}: ${formatObject(writeArgs)}`); + // static input policy check for top-level create data + const inputCheck = this.utils.checkInputGuard(this.model, args.data, 'create'); + if (inputCheck === false) { + throw this.utils.deniedByPolicy(this.model, 'create'); + } + + const hasNestedCreateOrConnect = await this.hasNestedCreateOrConnect(args); + + const { result, error } = await this.transaction(async (tx) => { + if ( + // MUST check true here since inputCheck can be undefined (meaning static input check not possible) + inputCheck === true && + // simple create: no nested create/connect + !hasNestedCreateOrConnect + ) { + // there's no nested write and we've passed input check, proceed with the create directly + + // validate zod schema if any + this.validateCreateInputSchema(this.model, args.data); + + // make a create args only containing data and ID selection + const createArgs: any = { data: args.data, select: this.utils.makeIdSelection(this.model) }; + + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`create\` ${this.model}: ${formatObject(createArgs)}`); + } + const result = await tx[this.model].create(createArgs); + + // filter the read-back data + return this.utils.readBack(tx, this.model, 'create', args, result); + } else { + // proceed with a complex create and collect post-write checks + const { result, postWriteChecks } = await this.doCreate(this.model, args, tx); + + // execute post-write checks + await this.runPostWriteChecks(postWriteChecks, tx); + + // filter the read-back data + return this.utils.readBack(tx, this.model, 'create', origArgs, result); } - return dbOps.create(writeArgs); }); - const ids = this.utils.getEntityIds(this.model, result); - if (Object.keys(ids).length === 0) { - throw this.utils.unknownError(`unexpected error: create didn't return an id`); + if (error) { + throw error; + } else { + return result; + } + } + + // create with nested write + private async doCreate(model: string, args: any, db: Record) { + // record id fields involved in the nesting context + const idSelections: Array<{ path: FieldInfo[]; ids: string[] }> = []; + const pushIdFields = (model: string, context: NestedWriteVisitorContext) => { + const idFields = getIdFields(this.modelMeta, model); + idSelections.push({ + path: context.nestingPath.map((p) => p.field).filter((f): f is FieldInfo => !!f), + ids: idFields.map((f) => f.name), + }); + }; + + // create a string key that uniquely identifies an entity + const getEntityKey = (model: string, ids: any) => + `${upperCaseFirst(model)}#${Object.keys(ids) + .sort() + .map((f) => `${f}:${ids[f]?.toString()}`) + .join('_')}`; + + // record keys of entities that are connected instead of created + const connectedEntities = new Set(); + + // visit the create payload + const visitor = new NestedWriteVisitor(this.modelMeta, { + create: async (model, args, context) => { + this.validateCreateInputSchema(model, args); + pushIdFields(model, context); + }, + + createMany: async (model, args, context) => { + enumerate(args.data).forEach((item) => this.validateCreateInputSchema(model, item)); + pushIdFields(model, context); + }, + + connectOrCreate: async (model, args, context) => { + if (!args.where) { + throw this.utils.validationError(`'where' field is required for connectOrCreate`); + } + + this.validateCreateInputSchema(model, args.create); + + const existing = await this.utils.checkExistence(db, model, args.where); + if (existing) { + // connect case + if (context.field?.backLink) { + const backLinkField = resolveField(this.modelMeta, model, context.field.backLink); + if (backLinkField?.isRelationOwner) { + // the target side of relation owns the relation, + // check if it's updatable + await this.utils.checkPolicyForUnique(model, args.where, 'update', db); + } + } + + if (context.parent.connect) { + // if the payload parent already has a "connect" clause, merge it + if (Array.isArray(context.parent.connect)) { + context.parent.connect.push(args.where); + } else { + context.parent.connect = [context.parent.connect, args.where]; + } + } else { + // otherwise, create a new "connect" clause + context.parent.connect = args.where; + } + // record the key of connected entities so we can avoid validating them later + connectedEntities.add(getEntityKey(model, existing)); + } else { + // create case + pushIdFields(model, context); + + // create a new "create" clause at the parent level + context.parent.create = args.create; + } + + // remove the connectOrCreate clause + delete context.parent['connectOrCreate']; + + // return false to prevent visiting the nested payload + return false; + }, + + connect: async (model, args, context) => { + if (!args || typeof args !== 'object' || Object.keys(args).length === 0) { + throw this.utils.validationError(`'connect' field must be an non-empty object`); + } + + if (context.field?.backLink) { + const backLinkField = resolveField(this.modelMeta, model, context.field.backLink); + if (backLinkField?.isRelationOwner) { + // check existence + await this.utils.checkExistence(db, model, args, true); + + // the target side of relation owns the relation, + // check if it's updatable + await this.utils.checkPolicyForUnique(model, args, 'update', db); + } + } + }, + }); + + await visitor.visit(model, 'create', args); + + // build the final "select" clause including all nested ID fields + let select: any = undefined; + if (idSelections.length > 0) { + select = {}; + idSelections.forEach(({ path, ids }) => { + let curr = select; + for (const p of path) { + if (!curr[p.name]) { + curr[p.name] = { select: {} }; + } + curr = curr[p.name].select; + } + Object.assign(curr, ...ids.map((f) => ({ [f]: true }))); + }); + } + + // proceed with the create + const createArgs = { data: args.data, select }; + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`create\` ${model}: ${formatObject(createArgs)}`); } + const result = await db[model].create(createArgs); + + // post create policy check for the top-level and nested creates + const postCreateChecks = new Map(); + + // visit the create result and collect entities that need to be post-checked + const modelDataVisitor = new ModelDataVisitor(this.modelMeta); + modelDataVisitor.visit(model, result, (model, _data, scalarData) => { + const key = getEntityKey(model, scalarData); + // only check if entity is created, not connected + if (!connectedEntities.has(key) && !postCreateChecks.has(key)) { + postCreateChecks.set(key, { model, operation: 'create', uniqueFilter: scalarData }); + } + }); + + // return only the ids of the top-level entity + const ids = this.utils.getEntityIds(this.model, result); + return { result: ids, postWriteChecks: [...postCreateChecks.values()] }; + } + + // Checks if the given create payload has nested create or connect + private async hasNestedCreateOrConnect(args: any) { + let hasNestedCreateOrConnect = false; + + const visitor = new NestedWriteVisitor(this.modelMeta, { + async create(_model, _args, context) { + if (context.field) { + hasNestedCreateOrConnect = true; + return false; + } else { + return true; + } + }, + async connect() { + hasNestedCreateOrConnect = true; + return false; + }, + async connectOrCreate() { + hasNestedCreateOrConnect = true; + return false; + }, + async createMany() { + hasNestedCreateOrConnect = true; + return false; + }, + }); + + await visitor.visit(this.model, 'create', args); + return hasNestedCreateOrConnect; + } - return this.checkReadback(origArgs, ids, 'create', 'create'); + // Validates the given create payload against Zod schema if any + private validateCreateInputSchema(model: string, data: any) { + const schema = this.utils.getZodSchema(model, 'create'); + if (schema) { + const parseResult = schema.safeParse(data); + if (!parseResult.success) { + throw this.utils.deniedByPolicy( + model, + 'create', + `input failed validation: ${fromZodError(parseResult.error)}`, + CrudFailureReason.DATA_VALIDATION_VIOLATION + ); + } + } } - async createMany(args: any, skipDuplicates?: boolean) { + async createMany(args: { data: any; skipDuplicates?: boolean }) { if (!args) { throw prismaClientValidationError(this.prisma, 'query argument is required'); } if (!args.data) { - throw prismaClientValidationError(this.prisma, 'data field is required and must be an array'); + throw prismaClientValidationError(this.prisma, 'data field is required in query argument'); } - await this.tryReject('create'); + this.utils.tryReject(this.model, 'create'); args = this.utils.clone(args); - // use a transaction to wrap the write so it can be reverted if any created - // entity fails access policies - const result = await this.utils.processWrite(this.model, 'create', args, (dbOps, writeArgs) => { - if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`createMany\` ${this.model}: ${formatObject(writeArgs)}`); + // do static input validation and check if post-create checks are needed + let needPostCreateCheck = false; + for (const item of enumerate(args.data)) { + const inputCheck = this.utils.checkInputGuard(this.model, item, 'create'); + if (inputCheck === false) { + throw this.utils.deniedByPolicy(this.model, 'create'); + } else if (inputCheck === true) { + this.validateCreateInputSchema(this.model, item); + } else if (inputCheck === undefined) { + // static policy check is not possible, need to do post-create check + needPostCreateCheck = true; + break; } - return dbOps.createMany(writeArgs, skipDuplicates); - }); + } - return result as BatchResult; + if (!needPostCreateCheck) { + return this.modelClient.createMany(args); + } else { + // create entities in a transaction with post-create checks + return this.transaction(async (tx) => { + const { result, postWriteChecks } = await this.doCreateMany(this.model, args, tx); + // post-create check + await this.runPostWriteChecks(postWriteChecks, tx); + return result; + }); + } } + private async doCreateMany( + model: string, + args: { data: any; skipDuplicates?: boolean }, + db: Record + ) { + // We can't call the native "createMany" because we can't get back what was created + // for post-create checks. Instead, do a "create" for each item and collect the results. + + let createResult = await Promise.all( + enumerate(args.data).map(async (item) => { + if (args.skipDuplicates) { + // check unique constraint conflicts + // we can't rely on try/catch/ignore constraint violation error: https://github.com/prisma/prisma/issues/20496 + // TODO: for simple cases we should be able to translate it to an `upsert` with empty `update` payload + + // for each unique constraint, check if the input item has all fields set, and if so, check if + // an entity already exists, and ignore accordingly + const uniqueConstraints = this.utils.getUniqueConstraints(model); + for (const constraint of Object.values(uniqueConstraints)) { + if (constraint.fields.every((f) => item[f] !== undefined)) { + const uniqueFilter = constraint.fields.reduce((acc, f) => ({ ...acc, [f]: item[f] }), {}); + const existing = await this.utils.checkExistence(db, model, uniqueFilter); + if (existing) { + if (this.shouldLogQuery) { + this.logger.info(`[policy] skipping duplicate ${formatObject(item)}`); + } + return undefined; + } + } + } + } + + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`create\` ${model}: ${formatObject(item)}`); + } + return await db[model].create({ select: this.utils.makeIdSelection(model), data: item }); + }) + ); + + // filter undefined values due to skipDuplicates + createResult = createResult.filter((p) => !!p); + + return { + result: { count: createResult.length }, + postWriteChecks: createResult.map((item) => ({ + model, + operation: 'create' as PolicyOperationKind, + uniqueFilter: item, + })), + }; + } + + //#endregion + + //#region Update & Upsert + + // "update" and "upsert" work against unique entity, so we actively rejects the request if the + // entity fails policy check + // + // "updateMany" works against a set of entities, entities not passing policy check are silently + // ignored + async update(args: any) { if (!args) { throw prismaClientValidationError(this.prisma, 'query argument is required'); @@ -166,25 +509,324 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'data field is required in query argument'); } - await this.tryReject('update'); + const { result, error } = await this.transaction(async (tx) => { + // proceed with nested writes and collect post-write checks + const { result, postWriteChecks } = await this.doUpdate(args, tx); - const origArgs = args; + // post-write check + await this.runPostWriteChecks(postWriteChecks, tx); + + // filter the read-back data + return this.utils.readBack(tx, this.model, 'update', args, result); + }); + + if (error) { + throw error; + } else { + return result; + } + } + + private async doUpdate(args: any, db: Record) { args = this.utils.clone(args); - // use a transaction to wrap the write so it can be reverted if any nested - // create fails access policies - const result: any = await this.utils.processWrite(this.model, 'update', args, (dbOps, writeArgs) => { - if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`update\` ${this.model}: ${formatObject(writeArgs)}`); + // collected post-update checks + const postWriteChecks: PostWriteCheckRecord[] = []; + + // registers a post-update check task + const _registerPostUpdateCheck = async (model: string, where: any, db: Record) => { + // both "post-update" rules and Zod schemas require a post-update check + if (this.utils.hasAuthGuard(model, 'postUpdate') || this.utils.getZodSchema(model)) { + // select pre-update field values + let preValue: any; + const preValueSelect = this.utils.getPreValueSelect(model); + if (preValueSelect && Object.keys(preValueSelect).length > 0) { + preValue = await db[model].findFirst({ where, select: preValueSelect }); + } + postWriteChecks.push({ model, operation: 'postUpdate', uniqueFilter: where, preValue }); + } + }; + + // We can't let the native "update" to handle nested "create" because we can't get back what + // was created for doing post-update checks. + // Instead, handle nested create inside update as an atomic operation that creates an entire + // subtree (containing nested creates/connects) + + const _create = async ( + model: string, + args: any, + context: NestedWriteVisitorContext, + db: Record + ) => { + let createData = args; + if (context.field?.backLink) { + // handles the connection to upstream entity + const reversedQuery = await this.utils.buildReversedQuery(context); + if (reversedQuery[context.field.backLink]) { + // the built reverse query contains a condition for the backlink field, build a "connect" with it + createData = { + ...createData, + [context.field.backLink]: { + connect: reversedQuery[context.field.backLink], + }, + }; + } else { + // otherwise, the reverse query is translated to foreign key setting, merge it to the create data + createData = { + ...createData, + ...reversedQuery, + }; + } + } + + // proceed with the create and collect post-create checks + const { postWriteChecks: checks } = await this.doCreate(model, { data: createData }, db); + postWriteChecks.push(...checks); + }; + + const _createMany = async ( + model: string, + args: any, + context: NestedWriteVisitorContext, + db: Record + ) => { + if (context.field?.backLink) { + // handles the connection to upstream entity + const reversedQuery = await this.utils.buildReversedQuery(context); + for (const item of enumerate(args.data)) { + Object.assign(item, reversedQuery); + } + } + // proceed with the create and collect post-create checks + const { postWriteChecks: checks } = await this.doCreateMany(model, args, db); + postWriteChecks.push(...checks); + }; + + const _connectDisconnect = async ( + model: string, + args: any, + context: NestedWriteVisitorContext, + db: Record + ) => { + if (context.field?.backLink) { + const backLinkField = this.utils.getModelField(model, context.field.backLink); + if (backLinkField.isRelationOwner) { + // update happens on the related model, require updatable + await this.utils.checkPolicyForUnique(model, args, 'update', db); + + // register post-update check + await _registerPostUpdateCheck(model, args, db); + } } - return dbOps.update(writeArgs); + }; + + // visit nested writes + const visitor = new NestedWriteVisitor(this.modelMeta, { + update: async (model, args, context) => { + // build a unique query including upstream conditions + const uniqueFilter = await this.utils.buildReversedQuery(context); + + // handle not-found + const existing = await this.utils.checkExistence(db, model, uniqueFilter, true); + + // check if the update actually writes to this model + let thisModelUpdate = false; + const updatePayload: any = (args as any).data ?? args; + if (updatePayload) { + for (const key of Object.keys(updatePayload)) { + const field = resolveField(this.modelMeta, model, key); + if (field) { + if (!field.isDataModel) { + // scalar field, require this model to be updatable + thisModelUpdate = true; + break; + } else if (field.isRelationOwner) { + // relation is being updated and this model owns foreign key, require updatable + thisModelUpdate = true; + break; + } + } + } + } + + if (thisModelUpdate) { + this.utils.tryReject(this.model, 'update'); + + // check pre-update guard + await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db); + + // handles the case where id fields are updated + const ids = this.utils.clone(existing); + for (const key of Object.keys(existing)) { + const updateValue = (args as any).data ? (args as any).data[key] : (args as any)[key]; + if ( + typeof updateValue === 'string' || + typeof updateValue === 'number' || + typeof updateValue === 'bigint' + ) { + ids[key] = updateValue; + } + } + + // register post-update check + await _registerPostUpdateCheck(model, ids, db); + } + }, + + updateMany: async (model, args, context) => { + // injects auth guard into where clause + await this.utils.injectAuthGuard(args, model, 'update'); + + // prepare for post-update check + if (this.utils.hasAuthGuard(model, 'postUpdate') || this.utils.getZodSchema(model)) { + let select = this.utils.makeIdSelection(model); + const preValueSelect = this.utils.getPreValueSelect(model); + if (preValueSelect) { + select = { ...select, ...preValueSelect }; + } + const reversedQuery = await this.utils.buildReversedQuery(context); + const currentSetQuery = { select, where: reversedQuery }; + await this.utils.injectAuthGuard(currentSetQuery, model, 'read'); + + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`findMany\` ${model}:\n${formatObject(currentSetQuery)}`); + } + const currentSet = await db[model].findMany(currentSetQuery); + + postWriteChecks.push( + ...currentSet.map((preValue) => ({ + model, + operation: 'postUpdate' as PolicyOperationKind, + uniqueFilter: preValue, + preValue: preValueSelect ? preValue : undefined, + })) + ); + } + }, + + create: async (model, args, context) => { + // process the entire create subtree separately + await _create(model, args, context, db); + + // remove it from the update payload + delete context.parent.create; + + // don't visit payload + return false; + }, + + createMany: async (model, args, context) => { + // process createMany separately + await _createMany(model, args, context, db); + + // remove it from the update payload + delete context.parent.createMany; + + // don't visit payload + return false; + }, + + upsert: async (model, args, context) => { + // build a unique query including upstream conditions + const uniqueFilter = await this.utils.buildReversedQuery(context); + + // branch based on if the update target exists + const existing = await this.utils.checkExistence(db, model, uniqueFilter); + if (existing) { + // update case + + // check pre-update guard + await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db); + + // register post-update check + await _registerPostUpdateCheck(model, uniqueFilter, uniqueFilter); + + // convert upsert to update + context.parent.update = { where: args.where, data: args.update }; + delete context.parent.upsert; + + // continue visiting the new payload + return context.parent.update; + } else { + // create case + + // process the entire create subtree separately + await _create(model, args.create, context, db); + + // remove it from the update payload + delete context.parent.upsert; + + // don't visit payload + return false; + } + }, + + connect: async (model, args, context) => _connectDisconnect(model, args, context, db), + + connectOrCreate: async (model, args, context) => { + // the where condition is already unique, so we can use it to check if the target exists + const existing = await this.utils.checkExistence(db, model, args.where); + if (existing) { + // connect + await _connectDisconnect(model, args.where, context, db); + } else { + // create + await _create(model, args.create, context, db); + } + }, + + disconnect: async (model, args, context) => _connectDisconnect(model, args, context, db), + + set: async (model, args, context) => { + // find the set of items to be replaced + const reversedQuery = await this.utils.buildReversedQuery(context); + const findCurrSetArgs = { + select: this.utils.makeIdSelection(model), + where: reversedQuery, + }; + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`findMany\` ${model}:\n${formatObject(findCurrSetArgs)}`); + } + const currentSet = await db[model].findMany(findCurrSetArgs); + + // register current set for update (foreign key) + await Promise.all(currentSet.map((item) => _connectDisconnect(model, item, context, db))); + + // proceed with connecting the new set + await Promise.all(enumerate(args).map((item) => _connectDisconnect(model, item, context, db))); + }, + + delete: async (model, args, context) => { + // build a unique query including upstream conditions + const uniqueFilter = await this.utils.buildReversedQuery(context); + + // handle not-found + await this.utils.checkExistence(db, model, uniqueFilter, true); + + // check delete guard + await this.utils.checkPolicyForUnique(model, uniqueFilter, 'delete', db); + }, + + deleteMany: async (model, args, context) => { + // inject delete guard + const guard = await this.utils.getAuthGuard(model, 'delete'); + context.parent.deleteMany = this.utils.and(args, guard); + }, }); - const ids = this.utils.getEntityIds(this.model, result); - if (Object.keys(ids).length === 0) { - throw this.utils.unknownError(`unexpected error: update didn't return an id`); + await visitor.visit(this.model, 'update', args); + + // finally proceed with the update + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`update\` ${this.model}: ${formatObject(args)}`); } - return this.checkReadback(origArgs, ids, 'update', 'update'); + const result = await db[this.model].update({ + where: args.where, + data: args.data, + select: this.utils.makeIdSelection(this.model), + }); + + return { result, postWriteChecks }; } async updateMany(args: any) { @@ -195,20 +837,53 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'data field is required in query argument'); } - await this.tryReject('update'); + await this.utils.tryReject(this.model, 'update'); args = this.utils.clone(args); - - // use a transaction to wrap the write so it can be reverted if any nested - // create fails access policies - const result = await this.utils.processWrite(this.model, 'updateMany', args, (dbOps, writeArgs) => { + await this.utils.injectAuthGuard(args, this.model, 'update'); + + if (this.utils.hasAuthGuard(this.model, 'postUpdate') || this.utils.getZodSchema(this.model)) { + // use a transaction to do post-update checks + const postWriteChecks: PostWriteCheckRecord[] = []; + return this.transaction(async (tx) => { + // collect pre-update values + let select = this.utils.makeIdSelection(this.model); + const preValueSelect = this.utils.getPreValueSelect(this.model); + if (preValueSelect) { + select = { ...select, ...preValueSelect }; + } + const currentSetQuery = { select, where: args.where }; + await this.utils.injectAuthGuard(currentSetQuery, this.model, 'read'); + + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`findMany\` ${this.model}: ${formatObject(currentSetQuery)}`); + } + const currentSet = await tx[this.model].findMany(currentSetQuery); + + postWriteChecks.push( + ...currentSet.map((preValue) => ({ + model: this.model, + operation: 'postUpdate' as PolicyOperationKind, + uniqueFilter: this.utils.getEntityIds(this.model, preValue), + preValue: preValueSelect ? preValue : undefined, + })) + ); + + // proceed with the update + const result = await tx[this.model].updateMany(args); + + // run post-write checks + await this.runPostWriteChecks(postWriteChecks, tx); + + return result; + }); + } else { + // proceed without a transaction if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`updateMany\` ${this.model}: ${formatObject(writeArgs)}`); + this.logger.info(`[policy] \`updateMany\` ${this.model}: ${formatObject(args)}`); } - return dbOps.updateMany(writeArgs); - }); - - return result as BatchResult; + return this.modelClient.updateMany(args); + } } async upsert(args: any) { @@ -225,29 +900,43 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'update field is required in query argument'); } - const origArgs = args; - args = this.utils.clone(args); - - await this.tryReject('create'); - await this.tryReject('update'); - - // use a transaction to wrap the write so it can be reverted if any nested - // create fails access policies - const result: any = await this.utils.processWrite(this.model, 'upsert', args, (dbOps, writeArgs) => { - if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`upsert\` ${this.model}: ${formatObject(writeArgs)}`); + await this.utils.tryReject(this.model, 'create'); + await this.utils.tryReject(this.model, 'update'); + + // We can call the native "upsert" because we can't tell if an entity was created or updated + // for doing post-write check accordingly. Instead, decompose it into create or update. + + const { result, error } = await this.transaction(async (tx) => { + const { where, create, update, ...rest } = args; + const existing = await this.utils.checkExistence(tx, this.model, args.where); + + if (existing) { + // update case + const { result, postWriteChecks } = await this.doUpdate({ where, data: update, ...rest }, tx); + await this.runPostWriteChecks(postWriteChecks, tx); + return this.utils.readBack(tx, this.model, 'update', args, result); + } else { + // create case + const { result, postWriteChecks } = await this.doCreate(this.model, { data: create, ...rest }, tx); + await this.runPostWriteChecks(postWriteChecks, tx); + return this.utils.readBack(tx, this.model, 'create', args, result); } - return dbOps.upsert(writeArgs); }); - const ids = this.utils.getEntityIds(this.model, result); - if (Object.keys(ids).length === 0) { - throw this.utils.unknownError(`unexpected error: upsert didn't return an id`); + if (error) { + throw error; + } else { + return result; } - - return this.checkReadback(origArgs, ids, 'upsert', 'update'); } + //#endregion + + //#region Delete + + // "delete" works against a single entity, and is rejected if the entity fails policy check. + // "deleteMany" works against a set of entities, entities that fail policy check are filtered out. + async delete(args: any) { if (!args) { throw prismaClientValidationError(this.prisma, 'query argument is required'); @@ -256,41 +945,38 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'where field is required in query argument'); } - await this.tryReject('delete'); + await this.utils.tryReject(this.model, 'delete'); - // ensures the item under deletion passes policy check - await this.utils.checkPolicyForFilter(this.model, args.where, 'delete', this.prisma); + const { result, error } = await this.transaction(async (tx) => { + // do a read-back before delete + const r = await this.utils.readBack(tx, this.model, 'delete', args, args.where); + const error = r.error; + const read = r.result; - // read the entity under deletion with respect to read policies - let readResult: any; - try { - const items = await this.utils.readWithCheck(this.model, args); - readResult = items[0]; - } catch (err) { - // not readable - readResult = undefined; - } + // check existence + await this.utils.checkExistence(tx, this.model, args.where, true); - // conduct the deletion - if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`delete\` ${this.model}:\n${formatObject(args)}`); - } - await this.modelClient.delete(args); + // inject delete guard + await this.utils.checkPolicyForUnique(this.model, args.where, 'delete', tx); + + // proceed with the deletion + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`delete\` ${this.model}:\n${formatObject(args)}`); + } + await tx[this.model].delete(args); + + return { result: read, error }; + }); - if (!readResult) { - throw this.utils.deniedByPolicy( - this.model, - 'delete', - 'result is not allowed to be read back', - CrudFailureReason.RESULT_NOT_READABLE - ); + if (error) { + throw error; } else { - return readResult; + return result; } } async deleteMany(args: any) { - await this.tryReject('delete'); + await this.utils.tryReject(this.model, 'delete'); // inject policy conditions args = args ?? {}; @@ -298,23 +984,25 @@ export class PolicyProxyHandler implements Pr // conduct the deletion if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`); + this.logger.info(`[policy] \`deleteMany\` ${this.model}:\n${formatObject(args)}`); } return this.modelClient.deleteMany(args); } + //#endregion + + //#region Aggregation + async aggregate(args: any) { if (!args) { throw prismaClientValidationError(this.prisma, 'query argument is required'); } - await this.tryReject('read'); - // inject policy conditions await this.utils.injectAuthGuard(args, this.model, 'read'); if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`aggregate\` ${this.model}:\n${formatObject(args)}`); + this.logger.info(`[policy] \`aggregate\` ${this.model}:\n${formatObject(args)}`); } return this.modelClient.aggregate(args); } @@ -324,60 +1012,50 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'query argument is required'); } - await this.tryReject('read'); - // inject policy conditions await this.utils.injectAuthGuard(args, this.model, 'read'); if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`groupBy\` ${this.model}:\n${formatObject(args)}`); + this.logger.info(`[policy] \`groupBy\` ${this.model}:\n${formatObject(args)}`); } return this.modelClient.groupBy(args); } async count(args: any) { - await this.tryReject('read'); - // inject policy conditions args = args ?? {}; await this.utils.injectAuthGuard(args, this.model, 'read'); if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`count\` ${this.model}:\n${formatObject(args)}`); + this.logger.info(`[policy] \`count\` ${this.model}:\n${formatObject(args)}`); } return this.modelClient.count(args); } - tryReject(operation: PolicyOperationKind) { - const guard = this.utils.getAuthGuard(this.model, operation); - if (guard === false) { - throw this.utils.deniedByPolicy(this.model, operation); - } + //#endregion + + //#region Utils + + private get shouldLogQuery() { + return !!this.logPrismaQuery && this.logger.enabled('info'); } - private async checkReadback( - origArgs: any, - ids: Record, - action: string, - operation: PolicyOperationKind - ) { - const readArgs = { select: origArgs.select, include: origArgs.include, where: ids }; - const result = await this.utils.readWithCheck(this.model, readArgs); - if (result.length === 0) { - this.logger.info(`${action} result cannot be read back`); - throw this.utils.deniedByPolicy( - this.model, - operation, - 'result is not allowed to be read back', - CrudFailureReason.RESULT_NOT_READABLE - ); - } else if (result.length > 1) { - throw this.utils.unknownError('write unexpected resulted in multiple readback entities'); - } - return result[0]; + private transaction(action: (tx: Record) => Promise) { + if (this.prisma[PRISMA_TX_FLAG]) { + // already in transaction, don't nest + return action(this.prisma); + } else { + return this.prisma.$transaction((tx) => action(tx)); + } } - private get shouldLogQuery() { - return this.logPrismaQuery && this.logger.enabled('info'); + private async runPostWriteChecks(postWriteChecks: PostWriteCheckRecord[], db: Record) { + await Promise.all( + postWriteChecks.map(async ({ model, operation, uniqueFilter, preValue }) => + this.utils.checkPolicyForUnique(model, uniqueFilter, operation, db, preValue) + ) + ); } + + //#endregion } diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 1f45a906f..e83c77454 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -1,31 +1,15 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { createId } from '@paralleldrive/cuid2'; import deepcopy from 'deepcopy'; import { lowerCaseFirst } from 'lower-case-first'; -import pluralize from 'pluralize'; import { upperCaseFirst } from 'upper-case-first'; import { fromZodError } from 'zod-validation-error'; -import { - AUXILIARY_FIELDS, - CrudFailureReason, - GUARD_FIELD_NAME, - PRISIMA_TX_FLAG, - PrismaErrorCode, - TRANSACTION_FIELD_NAME, -} from '../../constants'; -import { - AuthUser, - DbClientContract, - DbOperations, - FieldInfo, - PolicyOperationKind, - PrismaWriteActionType, -} from '../../types'; +import { AUXILIARY_FIELDS, CrudFailureReason, GUARD_FIELD_NAME, PrismaErrorCode } from '../../constants'; +import { AuthUser, DbClientContract, DbOperations, FieldInfo, PolicyOperationKind } from '../../types'; import { getVersion } from '../../version'; import { getFields, resolveField } from '../model-meta'; -import { NestedWriteVisitor, type NestedWriteVisitorContext } from '../nested-write-vistor'; -import type { ModelMeta, PolicyDef, PolicyFunc, ZodSchemas } from '../types'; +import { NestedWriteVisitorContext } from '../nested-write-vistor'; +import type { InputCheckFunc, ModelMeta, PolicyDef, PolicyFunc, ZodSchemas } from '../types'; import { enumerate, formatObject, @@ -33,6 +17,7 @@ import { getModelFields, prismaClientKnownRequestError, prismaClientUnknownRequestError, + prismaClientValidationError, } from '../utils'; import { Logger } from './logger'; @@ -50,15 +35,19 @@ export class PolicyUtil { private readonly policy: PolicyDef, private readonly zodSchemas: ZodSchemas | undefined, private readonly user?: AuthUser, - private readonly logPrismaQuery?: boolean + private readonly shouldLogQuery = false ) { this.logger = new Logger(db); } + //#region Logical operators + /** * Creates a conjunction of a list of query conditions. */ and(...conditions: (boolean | object)[]): any { + // TODO: reduction + if (conditions.includes(false)) { // always false return { [GUARD_FIELD_NAME]: false }; @@ -80,6 +69,8 @@ export class PolicyUtil { * Creates a disjunction of a list of query conditions. */ or(...conditions: (boolean | object)[]): any { + // TODO: reduction + if (conditions.includes(true)) { // always true return { [GUARD_FIELD_NAME]: true }; @@ -106,6 +97,10 @@ export class PolicyUtil { } } + //#endregion + + //# Auth guard + /** * Gets pregenerated authorization guard object for a given model and operation. * @@ -129,38 +124,74 @@ export class PolicyUtil { return provider({ user: this.user, preValue }); } - private hasValidation(model: string): boolean { - return this.policy.validation?.[lowerCaseFirst(model)]?.hasValidation === true; + /** + * Checks if the given model has a policy guard for the given operation. + */ + hasAuthGuard(model: string, operation: PolicyOperationKind): boolean { + const guard = this.policy.guard[lowerCaseFirst(model)]; + if (!guard) { + return false; + } + const provider: PolicyFunc | boolean | undefined = guard[operation]; + return typeof provider !== 'boolean' || provider !== true; } - private async getPreValueSelect(model: string): Promise { + /** + * Checks model creation policy based on static analysis to the input args. + * + * @returns boolean if static analysis is enough to determine the result, undefined if not + */ + checkInputGuard(model: string, args: any, operation: 'create'): boolean | undefined { const guard = this.policy.guard[lowerCaseFirst(model)]; if (!guard) { - throw this.unknownError(`unable to load policy guard for ${model}`); + return undefined; + } + + const provider: InputCheckFunc | boolean | undefined = guard[`${operation}_input` as const]; + + if (typeof provider === 'boolean') { + return provider; + } + + if (!provider) { + return undefined; } - return guard.preValueSelect; - } - private getModelSchema(model: string) { - return this.hasValidation(model) && this.zodSchemas?.models?.[`${upperCaseFirst(model)}Schema`]; + return provider(args, { user: this.user }); } /** * Injects model auth guard as where clause. */ async injectAuthGuard(args: any, model: string, operation: PolicyOperationKind) { + const guard = this.getAuthGuard(model, operation); + if (guard === false) { + // use OR with 0 filters to represent filtering out everything + // https://www.prisma.io/docs/concepts/components/prisma-client/null-and-undefined#the-effect-of-null-and-undefined-on-conditionals + args.where = { OR: [] }; + return false; + } + if (args.where) { // inject into relation fields: // to-many: some/none/every // to-one: direct-conditions/is/isNot - await this.injectGuardForFields(model, args.where, operation); + await this.injectGuardForRelationFields(model, args.where, operation); } - const guard = this.getAuthGuard(model, operation); - args.where = this.and(args.where, guard); + const combined = this.and(args.where, guard); + if (combined !== undefined) { + args.where = combined; + } else { + // use AND with 0 filters to represent no filtering + // https://www.prisma.io/docs/concepts/components/prisma-client/null-and-undefined#the-effect-of-null-and-undefined-on-conditionals + args.where = { AND: [] }; + } + + return true; } - async injectGuardForFields(model: string, payload: any, operation: PolicyOperationKind) { + private async injectGuardForRelationFields(model: string, payload: any, operation: PolicyOperationKind) { for (const [field, subPayload] of Object.entries(payload)) { if (!subPayload) { continue; @@ -179,19 +210,19 @@ export class PolicyUtil { } } - async injectGuardForToManyField( + private async injectGuardForToManyField( fieldInfo: FieldInfo, payload: { some?: any; every?: any; none?: any }, operation: PolicyOperationKind ) { const guard = this.getAuthGuard(fieldInfo.type, operation); if (payload.some) { - await this.injectGuardForFields(fieldInfo.type, payload.some, operation); + await this.injectGuardForRelationFields(fieldInfo.type, payload.some, operation); // turn "some" into: { some: { AND: [guard, payload.some] } } payload.some = this.and(payload.some, guard); } if (payload.none) { - await this.injectGuardForFields(fieldInfo.type, payload.none, operation); + await this.injectGuardForRelationFields(fieldInfo.type, payload.none, operation); // turn none into: { none: { AND: [guard, payload.none] } } payload.none = this.and(payload.none, guard); } @@ -201,7 +232,7 @@ export class PolicyUtil { // ignore empty every clause Object.keys(payload.every).length > 0 ) { - await this.injectGuardForFields(fieldInfo.type, payload.every, operation); + await this.injectGuardForRelationFields(fieldInfo.type, payload.every, operation); // turn "every" into: { none: { AND: [guard, { NOT: payload.every }] } } if (!payload.none) { @@ -212,7 +243,7 @@ export class PolicyUtil { } } - async injectGuardForToOneField( + private async injectGuardForToOneField( fieldInfo: FieldInfo, payload: { is?: any; isNot?: any } & Record, operation: PolicyOperationKind @@ -220,18 +251,18 @@ export class PolicyUtil { const guard = this.getAuthGuard(fieldInfo.type, operation); if (payload.is || payload.isNot) { if (payload.is) { - await this.injectGuardForFields(fieldInfo.type, payload.is, operation); + await this.injectGuardForRelationFields(fieldInfo.type, payload.is, operation); // turn "is" into: { is: { AND: [ originalIs, guard ] } payload.is = this.and(payload.is, guard); } if (payload.isNot) { - await this.injectGuardForFields(fieldInfo.type, payload.isNot, operation); + await this.injectGuardForRelationFields(fieldInfo.type, payload.isNot, operation); // turn "isNot" into: { isNot: { AND: [ originalIsNot, { NOT: guard } ] } } payload.isNot = this.and(payload.isNot, this.not(guard)); delete payload.isNot; } } else { - await this.injectGuardForFields(fieldInfo.type, payload, operation); + await this.injectGuardForRelationFields(fieldInfo.type, payload, operation); // turn direct conditions into: { is: { AND: [ originalConditions, guard ] } } const combined = this.and(deepcopy(payload), guard); Object.keys(payload).forEach((key) => delete payload[key]); @@ -240,60 +271,118 @@ export class PolicyUtil { } /** - * Read model entities w.r.t the given query args. The result list - * are guaranteed to fully satisfy 'read' policy rules recursively. - * - * For to-many relations involved, items not satisfying policy are - * silently trimmed. For to-one relation, if relation data fails policy - * an error is thrown. + * Injects auth guard for read operations. */ - async readWithCheck(model: string, args: any): Promise { - args = this.clone(args); + async injectForRead(model: string, args: any) { + const injected: any = {}; + if (!(await this.injectAuthGuard(injected, model, 'read'))) { + return false; + } if (args.where) { - // query args will be used with findMany, so we need to - // translate unique constraint filters into a flat filter - // e.g.: { a_b: { a: '1', b: '1' } } => { a: '1', b: '1' } - await this.flattenGeneratedUniqueField(model, args.where); + // inject into relation fields: + // to-many: some/none/every + // to-one: direct-conditions/is/isNot + await this.injectGuardForRelationFields(model, args.where, 'read'); } - await this.injectAuthGuard(args, model, 'read'); + if (injected.where && Object.keys(injected.where).length > 0) { + args.where = args.where ?? {}; + Object.assign(args.where, injected.where); + } // recursively inject read guard conditions into nested select, include, and _count const hoistedConditions = await this.injectNestedReadConditions(model, args); // the injection process may generate conditions that need to be hoisted to the toplevel, // if so, merge it with the existing where - if (hoistedConditions && Object.keys(hoistedConditions).length > 0) { - args.where = this.and(args.where, ...hoistedConditions); - } - - if (this.shouldLogQuery) { - this.logger.info(`[withPolicy] \`findMany\` ${model}:\n${formatObject(args)}`); + if (hoistedConditions.length > 0) { + args.where = args.where ?? {}; + Object.assign(args.where, ...hoistedConditions); } - const result: any[] = await this.db[model].findMany(args); - - this.postProcessForRead(result, args); - return result; + return true; } // flatten unique constraint filters - async flattenGeneratedUniqueField(model: string, args: any) { + private flattenGeneratedUniqueField(model: string, args: any) { // e.g.: { a_b: { a: '1', b: '1' } } => { a: '1', b: '1' } const uniqueConstraints = this.modelMeta.uniqueConstraints?.[lowerCaseFirst(model)]; if (uniqueConstraints && Object.keys(uniqueConstraints).length > 0) { for (const [field, value] of Object.entries(args)) { - if (uniqueConstraints[field] && typeof value === 'object') { + if ( + uniqueConstraints[field] && + uniqueConstraints[field].fields.length > 1 && + typeof value === 'object' + ) { + // multi-field unique constraint, flatten it + delete args[field]; for (const [f, v] of Object.entries(value)) { args[f] = v; } - delete args[field]; } } } } + /** + * Gets unique constraints for the given model. + */ + getUniqueConstraints(model: string) { + return this.modelMeta.uniqueConstraints?.[lowerCaseFirst(model)] ?? {}; + } + + /** + * Builds a reversed query for the given nested path. + */ + async buildReversedQuery(context: NestedWriteVisitorContext) { + let result, currQuery: any; + let currField: FieldInfo | undefined; + + for (let i = context.nestingPath.length - 1; i >= 0; i--) { + const { field, model, where } = context.nestingPath[i]; + + // never modify the original where because it's shared in the structure + const visitWhere = { ...where }; + if (model && where) { + // make sure composite unique condition is flattened + this.flattenGeneratedUniqueField(model, visitWhere); + } + + if (!result) { + // first segment (bottom), just use its where clause + result = currQuery = { ...visitWhere }; + currField = field; + } else { + if (!currField) { + throw this.unknownError(`missing field in nested path`); + } + if (!currField.backLink) { + throw this.unknownError(`field ${currField.type}.${currField.name} doesn't have a backLink`); + } + const backLinkField = this.getModelField(currField.type, currField.backLink); + if (backLinkField?.isArray) { + // many-side of relationship, wrap with "some" query + currQuery[currField.backLink] = { some: { ...visitWhere } }; + } else { + if (where && backLinkField.isRelationOwner && backLinkField.foreignKeyMapping) { + for (const [r, fk] of Object.entries(backLinkField.foreignKeyMapping)) { + currQuery[fk] = visitWhere[r]; + } + if (i > 0) { + currQuery[currField.backLink] = {}; + } + } else { + currQuery[currField.backLink] = { ...visitWhere }; + } + } + currQuery = currQuery[currField.backLink]; + currField = field; + } + } + return result; + } + private async injectNestedReadConditions(model: string, args: any): Promise { const injectTarget = args.select ?? args.include; if (!injectTarget) { @@ -379,393 +468,146 @@ export class PolicyUtil { } /** - * Post processing checks for read model entities. Validates to-one relations - * (which can't be trimmed at query time) and removes fields that should be - * omitted. - */ - private postProcessForRead(data: any, args: any) { - for (const entityData of enumerate(data)) { - if (typeof entityData !== 'object' || !entityData) { - return; - } - - // strip auxiliary fields - for (const auxField of AUXILIARY_FIELDS) { - if (auxField in entityData) { - delete entityData[auxField]; - } - } - - const injectTarget = args.select ?? args.include; - if (!injectTarget) { - return; - } - - // recurse into nested entities - for (const field of Object.keys(injectTarget)) { - const fieldData = entityData[field]; - if (typeof fieldData !== 'object' || !fieldData) { - continue; - } - this.postProcessForRead(fieldData, injectTarget[field]); - } - } - } - - /** - * Process Prisma write actions. + * Given a model and a unique filter, checks the operation is allowed by policies and field validations. + * Rejects with an error if not allowed. */ - async processWrite( + async checkPolicyForUnique( model: string, - action: PrismaWriteActionType, - args: any, - writeAction: (dbOps: DbOperations, writeArgs: any) => Promise + uniqueFilter: any, + operation: PolicyOperationKind, + db: Record, + preValue?: any ) { - // record model types for which new entities are created - // so we can post-check if they satisfy 'create' policies - const createdModels = new Set(); - - // record model entities that are updated, together with their - // values before update, so we can post-check if they satisfy - // model => { ids, entity value } - const updatedModels = new Map; value: any }>>(); - - function addUpdatedEntity(model: string, ids: Record, entity: any) { - let modelEntities = updatedModels.get(model); - if (!modelEntities) { - modelEntities = []; - updatedModels.set(model, modelEntities); - } - modelEntities.push({ ids, value: entity }); + const guard = this.getAuthGuard(model, operation, preValue); + if (guard === false) { + throw this.deniedByPolicy(model, operation, `entity ${formatObject(uniqueFilter)} failed policy check`); } - const idFields = this.getIdFields(model); - if (args.select) { - // make sure id fields are selected, we need it to - // read back the updated entity - for (const idField of idFields) { - if (!args.select[idField.name]) { - args.select[idField.name] = true; - } - } - } + // Zod schema is to be checked for "create" and "postUpdate" + const schema = ['create', 'postUpdate'].includes(operation) ? this.getZodSchema(model) : undefined; - // use a transaction to conduct write, so in case any create or nested create - // fails access policies, we can roll back the entire operation - const transactionId = createId(); - - // args processor for create - const processCreate = async (model: string, args: any) => { - const guard = this.getAuthGuard(model, 'create'); - const schema = this.getModelSchema(model); - if (guard === false) { - throw this.deniedByPolicy(model, 'create'); - } else if (guard !== true || schema) { - // mark the create with a transaction tag so we can check them later - args[TRANSACTION_FIELD_NAME] = `${transactionId}:create`; - createdModels.add(model); - } - }; - - // build a reversed query for fetching entities affected by nested updates - const buildReversedQuery = async (context: NestedWriteVisitorContext) => { - let result, currQuery: any; - let currField: FieldInfo | undefined; - - for (let i = context.nestingPath.length - 1; i >= 0; i--) { - const { field, model, where, unique } = context.nestingPath[i]; + if (guard === true && !schema) { + // unconditionally allowed + return; + } - // never modify the original where because it's shared in the structure - const visitWhere = { ...where }; - if (model && where) { - // make sure composite unique condition is flattened - await this.flattenGeneratedUniqueField(model, visitWhere); - } + const select = schema + ? // need to validate against schema, need to fetch all fields + undefined + : // only fetch id fields + this.makeIdSelection(model); - if (!result) { - // first segment (bottom), just use its where clause - result = currQuery = { ...visitWhere }; - currField = field; - } else { - if (!currField) { - throw this.unknownError(`missing field in nested path`); - } - if (!currField.backLink) { - throw this.unknownError(`field ${currField.type}.${currField.name} doesn't have a backLink`); - } - const backLinkField = this.getModelField(currField.type, currField.backLink); - if (backLinkField?.isArray) { - // many-side of relationship, wrap with "some" query - currQuery[currField.backLink] = { some: { ...visitWhere } }; - } else { - currQuery[currField.backLink] = { ...visitWhere }; - } - currQuery = currQuery[currField.backLink]; - currField = field; - } + let where = this.clone(uniqueFilter); + // query args may have be of combined-id form, need to flatten it to call findFirst + this.flattenGeneratedUniqueField(model, where); - if (unique) { - // hit a unique filter, no need to traverse further up - break; - } - } - return result; - }; - - // args processor for update/upsert - const processUpdate = async (model: string, where: any, context: NestedWriteVisitorContext) => { - const preGuard = this.getAuthGuard(model, 'update'); - if (preGuard === false) { - throw this.deniedByPolicy(model, 'update'); - } else if (preGuard !== true) { - if (this.isToOneRelation(context.field)) { - // To-one relation field is complicated because there's no way to - // filter it during update (args doesn't carry a 'where' clause). - // - // We need to recursively walk up its hierarcy in the query args - // to construct a reversed query to identify the nested entity - // under update, and then check if it satisfies policy. - // - // E.g.: - // A - B - C - // - // update A with: - // { - // where: { id: 'aId' }, - // data: { - // b: { - // c: { value: 1 } - // } - // } - // } - // - // To check if the update to 'c' field is permitted, we - // reverse the query stack into a filter for C model, like: - // { - // where: { - // b: { a: { id: 'aId' } } - // } - // } - // , and with this we can filter out the C entity that's going - // to be nestedly updated, and check if it's allowed. - // - // The same logic applies to nested delete. - - const subQuery = await buildReversedQuery(context); - await this.checkPolicyForFilter(model, subQuery, 'update', this.db); - } else { - if (!where) { - throw this.unknownError(`Missing 'where' parameter`); - } - await this.checkPolicyForFilter(model, where, 'update', this.db); - } - } - - await preparePostUpdateCheck(model, context); - }; - - // args processor for updateMany - const processUpdateMany = async (model: string, args: any, context: NestedWriteVisitorContext) => { - const guard = this.getAuthGuard(model, 'update'); - if (guard === false) { - throw this.deniedByPolicy(model, 'update'); - } else if (guard !== true) { - // inject policy filter - await this.injectAuthGuard(args, model, 'update'); - } + // query with policy guard + if (guard !== true) { + where = this.and(where, guard); + } + const query = { select, where }; - await preparePostUpdateCheck(model, context); - }; - - // for models with post-update rules, we need to read and store - // entity values before the update for post-update check - const preparePostUpdateCheck = async (model: string, context: NestedWriteVisitorContext) => { - const postGuard = this.getAuthGuard(model, 'postUpdate'); - const schema = this.getModelSchema(model); - - // post-update check is needed if there's post-update rule or validation schema - if (postGuard !== true || schema) { - // fetch preValue selection (analyzed from the post-update rules) - const preValueSelect = await this.getPreValueSelect(model); - const filter = await buildReversedQuery(context); - - // query args will be used with findMany, so we need to - // translate unique constraint filters into a flat filter - // e.g.: { a_b: { a: '1', b: '1' } } => { a: '1', b: '1' } - await this.flattenGeneratedUniqueField(model, filter); - - const idFields = this.getIdFields(model); - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const select: any = { ...preValueSelect }; - for (const idField of idFields) { - select[idField.name] = true; - } + if (this.shouldLogQuery) { + this.logger.info(`[policy] checking ${model} for ${operation}, \`findFirst\`:\n${formatObject(query)}`); + } + const result = await db[model].findFirst(query); + if (!result) { + throw this.deniedByPolicy(model, operation, `entity ${formatObject(uniqueFilter)} failed policy check`); + } - const query = { where: filter, select }; - if (this.shouldLogQuery) { - this.logger.info( - `[withPolicy] \`findMany\` ${model} for fetching pre-update entities:\n${formatObject(args)}` - ); - } - const entities = await this.db[model].findMany(query); - entities.forEach((entity) => { - addUpdatedEntity(model, this.getEntityIds(model, entity), entity); - }); - } - }; - - // args processor for delete - const processDelete = async (model: string, args: any, context: NestedWriteVisitorContext) => { - const guard = this.getAuthGuard(model, 'delete'); - if (guard === false) { - throw this.deniedByPolicy(model, 'delete'); - } else if (guard !== true) { - if (this.isToOneRelation(context.field)) { - // see comments in processUpdate - const subQuery = await buildReversedQuery(context); - await this.checkPolicyForFilter(model, subQuery, 'delete', this.db); - } else { - await this.checkPolicyForFilter(model, args, 'delete', this.db); - } - } - }; - - // process relation updates: connect, connectOrCreate, and disconnect - const processRelationUpdate = async (model: string, args: any, context: NestedWriteVisitorContext) => { - // CHECK ME: equire the entity being connected readable? - // await this.checkPolicyForFilter(model, args, 'read', this.db); - - if (context.field?.backLink) { - // fetch the backlink field of the model being connected - const backLinkField = resolveField(this.modelMeta, model, context.field.backLink); - if (backLinkField.isRelationOwner) { - // the target side of relation owns the relation, - // mark it as updated - await processUpdate(model, args, context); + if (schema) { + // TODO: push down schema check to the database + const parseResult = schema.safeParse(result); + if (!parseResult.success) { + const error = fromZodError(parseResult.error); + if (this.logger.enabled('info')) { + this.logger.info(`entity ${model} failed validation for operation ${operation}: ${error}`); } + throw this.deniedByPolicy( + model, + operation, + `entities ${JSON.stringify(uniqueFilter)} failed validation: [${error}]`, + CrudFailureReason.DATA_VALIDATION_VIOLATION + ); } - }; - - // use a visitor to process args before conducting the write action - const visitor = new NestedWriteVisitor(this.modelMeta, { - create: async (model, args) => { - await processCreate(model, args); - }, - - connectOrCreate: async (model, args, context) => { - if (args.create) { - await processCreate(model, args.create); - } - if (args.where) { - await processRelationUpdate(model, args.where, context); - } - }, - - connect: async (model, args, context) => { - await processRelationUpdate(model, args, context); - }, - - disconnect: async (model, args, context) => { - await processRelationUpdate(model, args, context); - }, - - update: async (model, args, context) => { - await processUpdate(model, args.where, context); - }, + } + } - updateMany: async (model, args, context) => { - await processUpdateMany(model, args, context); - }, + /** + * Tries rejecting a request based on static "false" policy. + */ + tryReject(model: string, operation: PolicyOperationKind) { + const guard = this.getAuthGuard(model, operation); + if (guard === false) { + throw this.deniedByPolicy(model, operation); + } + } - upsert: async (model, args, context) => { - if (args.create) { - await processCreate(model, args.create); - } + /** + * Checks if a model exists given a unique filter. + */ + async checkExistence( + db: Record, + model: string, + uniqueFilter: any, + throwIfNotFound = false + ): Promise { + uniqueFilter = this.clone(uniqueFilter); + this.flattenGeneratedUniqueField(model, uniqueFilter); - if (args.update) { - await processUpdate(model, args.where, context); - } - }, - - delete: async (model, args, context) => { - await processDelete(model, args, context); - }, - - // eslint-disable-next-line @typescript-eslint/no-unused-vars - deleteMany: async (model, args, _context) => { - const guard = this.getAuthGuard(model, 'delete'); - if (guard === false) { - throw this.deniedByPolicy(model, 'delete'); - } else if (guard !== true) { - if (args.where) { - args.where = this.and(args.where, guard); - } else { - const copy = deepcopy(args); - for (const key of Object.keys(args)) { - delete args[key]; - } - const combined = this.and(copy, guard); - Object.assign(args, combined); - } - } - }, + if (this.shouldLogQuery) { + this.logger.info(`[policy] checking ${model} existence, \`findFirst\`:\n${formatObject(uniqueFilter)}`); + } + const existing = await db[model].findFirst({ + where: uniqueFilter, + select: this.makeIdSelection(model), }); + if (!existing && throwIfNotFound) { + throw this.notFound(model); + } + return existing; + } - await visitor.visit(model, action, args); - - if (createdModels.size === 0 && updatedModels.size === 0) { - // no post-check needed, we can proceed with the write without transaction - return await writeAction(this.db[model], args); - } else { - return await this.transaction(this.db, async (tx) => { - // proceed with the update (with args processed) - const result = await writeAction(tx[model], args); - - if (createdModels.size > 0) { - // do post-check on created entities - await Promise.all( - [...createdModels].map((model) => - this.checkPolicyForFilter( - model, - { [TRANSACTION_FIELD_NAME]: `${transactionId}:create` }, - 'create', - tx - ) - ) - ); - } + /** + * Returns an entity given a unique filter with read policy checked. Reject if not readable. + */ + async readBack( + db: Record, + model: string, + operation: PolicyOperationKind, + selectInclude: { select?: any; include?: any }, + uniqueFilter: any + ): Promise<{ result: unknown; error?: Error }> { + uniqueFilter = this.clone(uniqueFilter); + this.flattenGeneratedUniqueField(model, uniqueFilter); + const readArgs = { select: selectInclude.select, include: selectInclude.include, where: uniqueFilter }; + const error = this.deniedByPolicy( + model, + operation, + 'result is not allowed to be read back', + CrudFailureReason.RESULT_NOT_READABLE + ); - if (updatedModels.size > 0) { - // do post-check on updated entities - await Promise.all( - [...updatedModels.entries()] - .map(([model, modelEntities]) => - modelEntities.map(async ({ ids, value: preValue }) => - this.checkPostUpdate(model, ids, tx, preValue) - ) - ) - .flat() - ); - } + const injectResult = await this.injectForRead(model, readArgs); + if (!injectResult) { + return { error, result: undefined }; + } - return result; - }); + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`findFirst\` ${model}:\n${formatObject(readArgs)}`); + } + const result = await db[model].findFirst(readArgs); + if (!result) { + return { error, result: undefined }; } - } - private getModelField(model: string, field: string) { - model = lowerCaseFirst(model); - return this.modelMeta.fields[model]?.[field]; + this.postProcessForRead(result); + return { result, error: undefined }; } - private transaction(db: DbClientContract, action: (tx: Record) => Promise) { - if (db[PRISIMA_TX_FLAG]) { - // already in transaction, don't nest - return action(db); - } else { - return db.$transaction((tx) => action(tx)); - } - } + //#endregion + + //#region Errors deniedByPolicy(model: string, operation: PolicyOperationKind, extra?: string, reason?: CrudFailureReason) { return prismaClientKnownRequestError( @@ -782,175 +624,103 @@ export class PolicyUtil { }); } + validationError(message: string) { + return prismaClientValidationError(this.db, message, { + clientVersion: getVersion(), + }); + } + unknownError(message: string) { return prismaClientUnknownRequestError(this.db, message, { clientVersion: getVersion(), }); } + //#endregion + + //#region Misc + /** - * Given a filter, check if applying access policy filtering will result - * in data being trimmed, and if so, throw an error. + * Gets field selection for fetching pre-update entity values for the given model. */ - async checkPolicyForFilter( - model: string, - filter: any, - operation: PolicyOperationKind, - db: Record - ) { - const guard = this.getAuthGuard(model, operation); - const schema = (operation === 'create' || operation === 'update') && this.getModelSchema(model); - - if (guard === true && !schema) { - // unconditionally allowed - return; + getPreValueSelect(model: string): object | undefined { + const guard = this.policy.guard[lowerCaseFirst(model)]; + if (!guard) { + throw this.unknownError(`unable to load policy guard for ${model}`); } + return guard.preValueSelect; + } - // if (this.logger.enabled('info')) { - // this.logger.info(`Checking policy for ${model}#${JSON.stringify(filter)} for ${operation}`); - // } - - const queryFilter = deepcopy(filter); + private hasFieldValidation(model: string): boolean { + return this.policy.validation?.[lowerCaseFirst(model)]?.hasValidation === true; + } - // query args will be used with findMany, so we need to - // translate unique constraint filters into a flat filter - // e.g.: { a_b: { a: '1', b: '1' } } => { a: '1', b: '1' } - await this.flattenGeneratedUniqueField(model, queryFilter); - - const countArgs = { where: queryFilter }; - // if (this.shouldLogQuery) { - // this.logger.info( - // `[withPolicy] \`count\` for policy check without guard:\n${formatObject(countArgs)}` - // ); - // } - const count = (await db[model].count(countArgs)) as number; - if (count === 0) { - // there's nothing to filter out - return; + /** + * Gets Zod schema for the given model and access kind. + * + * @param kind If undefined, returns the full schema. + */ + getZodSchema(model: string, kind: 'create' | 'update' | undefined = undefined) { + if (!this.hasFieldValidation(model)) { + return undefined; } + const schemaKey = `${upperCaseFirst(model)}${kind ? upperCaseFirst(kind) : ''}Schema`; + return this.zodSchemas?.models?.[schemaKey]; + } - if (guard === false) { - // unconditionally denied - throw this.deniedByPolicy(model, operation, `${count} ${pluralize('entity', count)} failed policy check`); + /** + * Post processing checks and clean-up for read model entities. + */ + postProcessForRead(data: any) { + if (data === null || data === undefined) { + return; } - // build a query condition with policy injected - const guardedQuery = { where: this.and(queryFilter, guard) }; - - if (schema) { - // we've got schemas, so have to fetch entities and validate them - // if (this.shouldLogQuery) { - // this.logger.info( - // `[withPolicy] \`findMany\` for policy check with guard:\n${formatObject(countArgs)}` - // ); - // } - const entities = await db[model].findMany(guardedQuery); - if (entities.length < count) { - if (this.logger.enabled('info')) { - this.logger.info(`entity ${model} failed policy check for operation ${operation}`); - } - throw this.deniedByPolicy( - model, - operation, - `${count - entities.length} ${pluralize('entity', count - entities.length)} failed policy check` - ); + for (const entityData of enumerate(data)) { + if (typeof entityData !== 'object' || !entityData) { + return; } - // TODO: push down schema check to the database - const schemaCheckErrors = entities.map((entity) => schema.safeParse(entity)).filter((r) => !r.success); - if (schemaCheckErrors.length > 0) { - const error = schemaCheckErrors.map((r) => !r.success && fromZodError(r.error).message).join(', '); - if (this.logger.enabled('info')) { - this.logger.info(`entity ${model} failed schema check for operation ${operation}: ${error}`); - } - throw this.deniedByPolicy( - model, - operation, - `entities failed schema check: [${error}]`, - CrudFailureReason.DATA_VALIDATION_VIOLATION - ); - } - } else { - // count entities with policy injected and see if any of them are filtered out - // if (this.shouldLogQuery) { - // this.logger.info( - // `[withPolicy] \`count\` for policy check with guard:\n${formatObject(guardedQuery)}` - // ); - // } - const guardedCount = (await db[model].count(guardedQuery)) as number; - if (guardedCount < count) { - if (this.logger.enabled('info')) { - this.logger.info(`entity ${model} failed policy check for operation ${operation}`); + // strip auxiliary fields + for (const auxField of AUXILIARY_FIELDS) { + if (auxField in entityData) { + delete entityData[auxField]; } - throw this.deniedByPolicy( - model, - operation, - `${count - guardedCount} ${pluralize('entity', count - guardedCount)} failed policy check` - ); } - } - } - - private async checkPostUpdate( - model: string, - ids: Record, - db: Record, - preValue: any - ) { - // if (this.logger.enabled('info')) { - // this.logger.info(`Checking post-update policy for ${model}#${ids}, preValue: ${formatObject(preValue)}`); - // } - - const guard = this.getAuthGuard(model, 'postUpdate', preValue); - // build a query condition with policy injected - const guardedQuery = { where: this.and(ids, guard) }; - - // query with policy injected - const entity = await db[model].findFirst(guardedQuery); - - // see if we get fewer items with policy, if so, reject with an throw - if (!entity) { - if (this.logger.enabled('info')) { - this.logger.info(`entity ${model} failed policy check for operation postUpdate`); - } - throw this.deniedByPolicy(model, 'postUpdate'); - } - - // TODO: push down schema check to the database - const schema = this.getModelSchema(model); - if (schema) { - const schemaCheckResult = schema.safeParse(entity); - if (!schemaCheckResult.success) { - const error = fromZodError(schemaCheckResult.error).message; - if (this.logger.enabled('info')) { - this.logger.info(`entity ${model} failed schema check for operation postUpdate: ${error}`); + for (const fieldData of Object.values(entityData)) { + if (typeof fieldData !== 'object' || !fieldData) { + continue; } - throw this.deniedByPolicy(model, 'postUpdate', `entity failed schema check: ${error}`); + this.postProcessForRead(fieldData); } } } - private isToOneRelation(field: FieldInfo | undefined) { - return !!field && field.isDataModel && !field.isArray; + /** + * Gets information for a specific model field. + */ + getModelField(model: string, field: string) { + model = lowerCaseFirst(model); + return this.modelMeta.fields[model]?.[field]; } /** * Clones an object and makes sure it's not empty. */ - clone(value: unknown) { + clone(value: unknown): any { return value ? deepcopy(value) : {}; } /** - * Gets "id" field for a given model. + * Gets "id" fields for a given model. */ getIdFields(model: string) { return getIdFields(this.modelMeta, model, true); } /** - * Gets id field value from an entity. + * Gets id field values from an entity. */ getEntityIds(model: string, entityData: any) { const idFields = this.getIdFields(model); @@ -961,7 +731,13 @@ export class PolicyUtil { return result; } - private get shouldLogQuery() { - return this.logPrismaQuery && this.logger.enabled('info'); + /** + * Creates a selection object for id fields for the given model. + */ + makeIdSelection(model: string) { + const idFields = this.getIdFields(model); + return Object.assign({}, ...idFields.map((f) => ({ [f.name]: true }))); } + + //#endregion } diff --git a/packages/runtime/src/enhancements/proxy.ts b/packages/runtime/src/enhancements/proxy.ts index 65966c4c3..43cc36a30 100644 --- a/packages/runtime/src/enhancements/proxy.ts +++ b/packages/runtime/src/enhancements/proxy.ts @@ -1,6 +1,6 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { PRISIMA_TX_FLAG, PRISMA_PROXY_ENHANCER } from '../constants'; +import { PRISMA_TX_FLAG, PRISMA_PROXY_ENHANCER } from '../constants'; import { DbClientContract } from '../types'; import { ModelMeta } from './types'; @@ -71,7 +71,7 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { async findFirst(args: any): Promise { args = await this.preprocessArgs('findFirst', args); - const r = this.prisma[this.model].findFirst(args); + const r = await this.prisma[this.model].findFirst(args); return this.processResultEntity(r); } @@ -100,7 +100,7 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { async update(args: any): Promise { args = await this.preprocessArgs('update', args); - const r = this.prisma[this.model].update(args); + const r = await this.prisma[this.model].update(args); return this.processResultEntity(r); } @@ -111,13 +111,13 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { async upsert(args: any): Promise { args = await this.preprocessArgs('upsert', args); - const r = this.prisma[this.model].upsert(args); + const r = await this.prisma[this.model].upsert(args); return this.processResultEntity(r); } async delete(args: any): Promise { args = await this.preprocessArgs('delete', args); - const r = this.prisma[this.model].delete(args); + const r = await this.prisma[this.model].delete(args); return this.processResultEntity(r); } @@ -204,7 +204,7 @@ export function makeProxy( const txFunc = input; return $transaction.bind(target)((tx: any) => { const txProxy = makeProxy(tx, modelMeta, makeHandler, name + '$tx'); - txProxy[PRISIMA_TX_FLAG] = true; + txProxy[PRISMA_TX_FLAG] = true; return txFunc(txProxy); }, ...rest); }; diff --git a/packages/runtime/src/enhancements/types.ts b/packages/runtime/src/enhancements/types.ts index 8834e73c2..3de460126 100644 --- a/packages/runtime/src/enhancements/types.ts +++ b/packages/runtime/src/enhancements/types.ts @@ -1,3 +1,4 @@ +/* eslint-disable @typescript-eslint/no-explicit-any */ import { z } from 'zod'; import { FieldInfo, PolicyOperationKind, QueryContext } from '../types'; @@ -20,6 +21,11 @@ export type ModelMeta = { */ export type PolicyFunc = (context: QueryContext) => object; +/** + * Function for getting policy guard with a given context + */ +export type InputCheckFunc = (args: any, context: QueryContext) => boolean; + /** * Policy definition */ @@ -31,6 +37,8 @@ export type PolicyDef = { allowAll?: boolean; denyAll?: boolean; } & Partial> & { + create_input: InputCheckFunc; + } & { preValueSelect?: object; } >; diff --git a/packages/runtime/src/enhancements/utils.ts b/packages/runtime/src/enhancements/utils.ts index bbe85a5ce..4ba822f3c 100644 --- a/packages/runtime/src/enhancements/utils.ts +++ b/packages/runtime/src/enhancements/utils.ts @@ -38,7 +38,9 @@ export type Enumerable = T | Array; * Uniformly enumerates an array or scalar. */ export function enumerate(x: Enumerable) { - if (Array.isArray(x)) { + if (x === null || x === undefined) { + return []; + } else if (Array.isArray(x)) { return x; } else { return [x]; diff --git a/packages/runtime/src/error.ts b/packages/runtime/src/error.ts index 0c8119cb1..22422e7b8 100644 --- a/packages/runtime/src/error.ts +++ b/packages/runtime/src/error.ts @@ -2,13 +2,20 @@ export function isPrismaClientKnownRequestError( err: any ): err is { code: string; message: string; meta?: Record } { - return err.__proto__.constructor.name === 'PrismaClientKnownRequestError'; + return findConstructorName(err.__proto__, 'PrismaClientKnownRequestError'); } export function isPrismaClientUnknownRequestError(err: any): err is { message: string } { - return err.__proto__.constructor.name === 'PrismaClientUnknownRequestError'; + return findConstructorName(err.__proto__, 'PrismaClientUnknownRequestError'); } export function isPrismaClientValidationError(err: any): err is { message: string } { - return err.__proto__.constructor.name === 'PrismaClientValidationError'; + return findConstructorName(err.__proto__, 'PrismaClientValidationError'); +} + +function findConstructorName(proto: any, name: string): boolean { + if (!proto) { + return false; + } + return proto.constructor.name === name || findConstructorName(proto.__proto__, name); } diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index b8ecb0122..ced5bc699 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -100,10 +100,15 @@ export type FieldInfo = { * If the field is the owner side of a relation */ isRelationOwner: boolean; + + /** + * Mapping from foreign key field names to relation field names + */ + foreignKeyMapping?: Record; }; export type DbClientContract = Record & { - $transaction: (action: (tx: Record) => Promise) => Promise; + $transaction: (action: (tx: Record) => Promise, options?: unknown) => Promise; }; export const PrismaWriteActions = [ @@ -115,6 +120,7 @@ export const PrismaWriteActions = [ 'upsert', 'connect', 'disconnect', + 'set', 'delete', 'deleteMany', ] as const; diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index 6758eb7d3..6102528a6 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -24,6 +24,7 @@ import { getLiteral, getPrismaClientImportSpec, GUARD_FIELD_NAME, + hasAttribute, hasValidationAttributes, PluginError, PluginOptions, @@ -35,7 +36,13 @@ import { import { streamAllContents } from 'langium'; import { lowerCaseFirst } from 'lower-case-first'; import path from 'path'; -import { FunctionDeclaration, SourceFile, VariableDeclarationKind } from 'ts-morph'; +import { + FunctionDeclaration, + SourceFile, + StatementStructures, + VariableDeclarationKind, + WriterFunction, +} from 'ts-morph'; import { name } from '.'; import { isFromStdlib } from '../../language-server/utils'; import { getIdFields, isAuthInvocation } from '../../utils/ast-utils'; @@ -192,7 +199,7 @@ export default class PolicyGenerator { } private hasFutureReference(expr: Expression) { - for (const node of streamAllContents(expr)) { + for (const node of this.allNodes(expr)) { if (isInvocationExpr(node) && node.function.ref?.name === 'future' && isFromStdlib(node.function.ref)) { return true; } @@ -209,6 +216,9 @@ export default class PolicyGenerator { for (const kind of ALL_OPERATION_KINDS) { if (policies[kind] === true || policies[kind] === false) { result[kind] = policies[kind]; + if (kind === 'create') { + result[kind + '_input'] = policies[kind]; + } continue; } @@ -233,9 +243,9 @@ export default class PolicyGenerator { continue; } - const func = this.generateQueryGuardFunction(sourceFile, model, kind, allows, denies); + const guardFunc = this.generateQueryGuardFunction(sourceFile, model, kind, allows, denies); // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - result[kind] = func.getName()!; + result[kind] = guardFunc.getName()!; if (kind === 'postUpdate') { const preValueSelect = this.generatePreValueSelect(model, allows, denies); @@ -243,10 +253,45 @@ export default class PolicyGenerator { result['preValueSelect'] = preValueSelect; } } + + if (kind === 'create' && this.canCheckCreateBasedOnInput(model, allows, denies)) { + const inputCheckFunc = this.generateInputCheckFunction(sourceFile, model, kind, allows, denies); + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + result[kind + '_input'] = inputCheckFunc.getName()!; + } } return result; } + private canCheckCreateBasedOnInput(model: DataModel, allows: Expression[], denies: Expression[]) { + return [...allows, ...denies].every((rule) => { + return [...this.allNodes(rule)].every((expr) => { + if (isThisExpr(expr)) { + return false; + } + if (isReferenceExpr(expr)) { + if (isDataModel(expr.$resolvedType?.decl)) { + // if policy rules uses relation fields, + // we can't check based on create input + return false; + } + if ( + isDataModelField(expr.target.ref) && + expr.target.ref.$container === model && + hasAttribute(expr.target.ref, '@default') + ) { + // reference to field of current model + // if it has default value, we can't check + // based on create input + return false; + } + } + + return true; + }); + }); + } + // generates an object that can be used as the 'select' argument when fetching pre-update // entity value private generatePreValueSelect(model: DataModel, allows: Expression[], denies: Expression[]): object { @@ -289,7 +334,7 @@ export default class PolicyGenerator { }; for (const rule of [...allows, ...denies]) { - for (const expr of streamAllContents(rule).filter((node): node is Expression => isExpression(node))) { + for (const expr of [...this.allNodes(rule)].filter((node): node is Expression => isExpression(node))) { // only care about member access and reference expressions if (!isMemberAccessExpr(expr) && !isReferenceExpr(expr)) { continue; @@ -321,22 +366,11 @@ export default class PolicyGenerator { allows: Expression[], denies: Expression[] ): FunctionDeclaration { - const func = sourceFile - .addFunction({ - name: model.name + '_' + kind, - returnType: 'any', - parameters: [ - { - name: 'context', - type: 'QueryContext', - }, - ], - }) - .addBody(); + const statements: (string | WriterFunction | StatementStructures)[] = []; // check if any allow or deny rule contains 'auth()' invocation const hasAuthRef = [...denies, ...allows].some((rule) => - streamAllContents(rule).some((child) => isAuthInvocation(child)) + [...this.allNodes(rule)].some((child) => isAuthInvocation(child)) ); if (hasAuthRef) { @@ -352,7 +386,7 @@ export default class PolicyGenerator { } // normalize user to null to avoid accidentally use undefined in filter - func.addStatements( + statements.push( `const user = hasAllFields(context.user, [${userIdFields .map((f) => "'" + f.name + "'") .join(', ')}]) ? context.user as any : null;` @@ -360,7 +394,7 @@ export default class PolicyGenerator { } const hasFieldAccess = [...denies, ...allows].some((rule) => - [rule, ...streamAllContents(rule)].some( + [...this.allNodes(rule)].some( (child) => // this.??? isThisExpr(child) || @@ -374,7 +408,7 @@ export default class PolicyGenerator { if (!hasFieldAccess) { // none of the rules reference model fields, we can compile down to a plain boolean // function in this case (so we can skip doing SQL queries when validating) - func.addStatements((writer) => { + statements.push((writer) => { const transformer = new TypeScriptExpressionTransformer({ context: ExpressionContext.AccessPolicy, isPostGuard: kind === 'postUpdate', @@ -396,7 +430,7 @@ export default class PolicyGenerator { writer.write('return false;'); }); } else { - func.addStatements((writer) => { + statements.push((writer) => { writer.write('return '); const exprWriter = new ExpressionWriter(writer, kind === 'postUpdate'); const writeDenies = () => { @@ -437,6 +471,109 @@ export default class PolicyGenerator { writer.write(';'); }); } + + const func = sourceFile.addFunction({ + name: model.name + '_' + kind, + returnType: 'any', + parameters: [ + { + name: 'context', + type: 'QueryContext', + }, + ], + statements, + }); + + return func; + } + + private generateInputCheckFunction( + sourceFile: SourceFile, + model: DataModel, + kind: 'create' | 'update', + allows: Expression[], + denies: Expression[] + ): FunctionDeclaration { + const statements: (string | WriterFunction | StatementStructures)[] = []; + + // check if any allow or deny rule contains 'auth()' invocation + const hasAuthRef = [...denies, ...allows].some((rule) => + [...this.allNodes(rule)].some((child) => isAuthInvocation(child)) + ); + + if (hasAuthRef) { + const userModel = model.$container.declarations.find( + (decl): decl is DataModel => isDataModel(decl) && decl.name === 'User' + ); + if (!userModel) { + throw new PluginError(name, 'User model not found'); + } + const userIdFields = getIdFields(userModel); + if (!userIdFields || userIdFields.length === 0) { + throw new PluginError(name, 'User model does not have an id field'); + } + + // normalize user to null to avoid accidentally use undefined in filter + statements.push( + `const user = hasAllFields(context.user, [${userIdFields + .map((f) => "'" + f.name + "'") + .join(', ')}]) ? context.user as any : null;` + ); + } + + statements.push((writer) => { + if (allows.length === 0) { + writer.write('return false;'); + return; + } + + const transformer = new TypeScriptExpressionTransformer({ + context: ExpressionContext.AccessPolicy, + fieldReferenceContext: 'input', + }); + + let expr = + denies.length > 0 + ? '!(' + + denies + .map((deny) => { + return transformer.transform(deny); + }) + .join(' || ') + + ')' + : undefined; + + const allowStmt = allows + .map((allow) => { + return transformer.transform(allow); + }) + .join(' || '); + + expr = expr ? `${expr} && (${allowStmt})` : allowStmt; + writer.write('return ' + expr); + }); + + const func = sourceFile.addFunction({ + name: model.name + '_' + kind + '_input', + returnType: 'boolean', + parameters: [ + { + name: 'input', + type: 'any', + }, + { + name: 'context', + type: 'QueryContext', + }, + ], + statements, + }); + return func; } + + private *allNodes(expr: Expression) { + yield expr; + yield* streamAllContents(expr); + } } diff --git a/packages/schema/src/plugins/model-meta/index.ts b/packages/schema/src/plugins/model-meta/index.ts index ce6d82785..8e488529a 100644 --- a/packages/schema/src/plugins/model-meta/index.ts +++ b/packages/schema/src/plugins/model-meta/index.ts @@ -2,8 +2,10 @@ import { ArrayExpr, DataModel, DataModelField, + isArrayExpr, isDataModel, isLiteralExpr, + isReferenceExpr, Model, ReferenceExpr, } from '@zenstackhq/language/ast'; @@ -68,6 +70,7 @@ function generateModelMetadata(dataModels: DataModel[], writer: CodeBlockWriter) writer.block(() => { for (const f of model.fields) { const backlink = getBackLink(f); + const fkMapping = generateForeignKeyMapping(f); writer.write(`${f.name}: { name: "${f.name}", type: "${ @@ -83,6 +86,7 @@ function generateModelMetadata(dataModels: DataModel[], writer: CodeBlockWriter) attributes: ${JSON.stringify(getFieldAttributes(f))}, backLink: ${backlink ? "'" + backlink.name + "'" : 'undefined'}, isRelationOwner: ${isRelationOwner(f, backlink)}, + foreignKeyMapping: ${fkMapping ? JSON.stringify(fkMapping) : 'undefined'} },`); } }); @@ -159,6 +163,8 @@ function getFieldAttributes(field: DataModelField): RuntimeAttribute[] { function getUniqueConstraints(model: DataModel) { const constraints: Array<{ name: string; fields: string[] }> = []; + + // model-level constraints for (const attr of model.attributes.filter( (attr) => attr.decl.ref?.name === '@@unique' || attr.decl.ref?.name === '@@id' )) { @@ -175,6 +181,14 @@ function getUniqueConstraints(model: DataModel) { constraints.push({ name: constraintName, fields: fieldNames }); } } + + // field-level constraints + for (const field of model.fields) { + if (hasAttribute(field, '@id') || hasAttribute(field, '@unique')) { + constraints.push({ name: field.name, fields: [field.name] }); + } + } + return constraints; } @@ -205,3 +219,28 @@ function holdsForeignKey(field: DataModelField) { const fields = getAttributeArg(relation, 'fields'); return !!fields; } + +function generateForeignKeyMapping(field: DataModelField) { + const relation = field.attributes.find((attr) => attr.decl.ref?.name === '@relation'); + if (!relation) { + return undefined; + } + const fields = getAttributeArg(relation, 'fields'); + const references = getAttributeArg(relation, 'references'); + if (!isArrayExpr(fields) || !isArrayExpr(references) || fields.items.length !== references.items.length) { + return undefined; + } + + const fieldNames = fields.items.map((item) => (isReferenceExpr(item) ? item.target.$refText : undefined)); + const referenceNames = references.items.map((item) => (isReferenceExpr(item) ? item.target.$refText : undefined)); + + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const result: Record = {}; + referenceNames.forEach((name, i) => { + if (name) { + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + result[name] = fieldNames[i]!; + } + }); + return result; +} diff --git a/packages/schema/src/utils/typescript-expression-transformer.ts b/packages/schema/src/utils/typescript-expression-transformer.ts index 2629f3a2f..fb9ac41cf 100644 --- a/packages/schema/src/utils/typescript-expression-transformer.ts +++ b/packages/schema/src/utils/typescript-expression-transformer.ts @@ -271,7 +271,7 @@ export class TypeScriptExpressionTransformer { return `context.preValue?.${expr.target.ref.name}`; } else { return this.options?.fieldReferenceContext - ? `${this.options.fieldReferenceContext}.${expr.target.ref.name}` + ? `${this.options.fieldReferenceContext}?.${expr.target.ref.name}` : expr.target.ref.name; } } diff --git a/packages/server/src/api/rpc/index.ts b/packages/server/src/api/rpc/index.ts index 57a9437ef..b8615788e 100644 --- a/packages/server/src/api/rpc/index.ts +++ b/packages/server/src/api/rpc/index.ts @@ -15,6 +15,12 @@ import { logError, processEntityData, registerCustomSerializers } from '../utils registerCustomSerializers(); +const ERROR_STATUS_MAPPING: Record = { + [PrismaErrorCode.CONSTRAINED_FAILED]: 403, + [PrismaErrorCode.REQUIRED_CONNECTED_RECORD_NOT_FOUND]: 404, + [PrismaErrorCode.DEPEND_ON_RECORD_NOT_FOUND]: 404, +}; + /** * Prisma RPC style API request handler that mirrors the Prisma Client API */ @@ -149,33 +155,20 @@ class RequestHandler extends APIHandlerBase { } catch (err) { if (isPrismaClientKnownRequestError(err)) { logError(logger, err.code, err.message); - if (err.code === PrismaErrorCode.CONSTRAINED_FAILED) { - // rejected by policy - return { - status: 403, - body: { - error: { - prisma: true, - rejectedByPolicy: true, - code: err.code, - message: err.message, - reason: err.meta?.reason, - }, - }, - }; - } else { - return { - status: 400, - body: { - error: { - prisma: true, - code: err.code, - message: err.message, - reason: err.meta?.reason, - }, + const status = ERROR_STATUS_MAPPING[err.code] ?? 400; + const rejectedByPolicy = err.code === PrismaErrorCode.CONSTRAINED_FAILED ? true : undefined; + return { + status, + body: { + error: { + prisma: true, + rejectedByPolicy, + code: err.code, + message: err.message, + reason: err.meta?.reason, }, - }; - } + }, + }; } else if (isPrismaClientUnknownRequestError(err) || isPrismaClientValidationError(err)) { logError(logger, err.message); return { diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index 562672bfc..ac00a2d27 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -210,7 +210,7 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) { console.log('Compiling...'); run('npx tsc --init'); - // add genetated '.zenstack/zod' folder to typescript's search path, + // add generated '.zenstack/zod' folder to typescript's search path, // so that it can be resolved from symbolic-linked files const tsconfig = json.parse(fs.readFileSync(path.join(projectRoot, './tsconfig.json'), 'utf-8')); tsconfig.compilerOptions.paths = { diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 4c87fa7a4..a2294dad3 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -345,9 +345,6 @@ importers: packages/runtime: dependencies: - '@paralleldrive/cuid2': - specifier: ^2.2.0 - version: 2.2.0 '@types/bcryptjs': specifier: ^2.4.2 version: 2.4.2 @@ -790,6 +787,9 @@ importers: decimal.js: specifier: ^10.4.2 version: 10.4.2 + pg: + specifier: ^8.11.1 + version: 8.11.1 sleep-promise: specifier: ^9.1.0 version: 9.1.0 @@ -806,6 +806,9 @@ importers: '@types/jest': specifier: ^29.5.0 version: 29.5.0 + '@types/pg': + specifier: ^8.10.2 + version: 8.10.2 '@types/supertest': specifier: ^2.0.12 version: 2.0.12 @@ -3681,6 +3684,14 @@ packages: /@types/normalize-package-data@2.4.1: resolution: {integrity: sha512-Gj7cI7z+98M282Tqmp2K5EIsoouUEzbBJhQQzDE3jSIRk6r9gsz0oUokqIUR4u1R3dMHo0pDHM7sNOHyhulypw==} + /@types/pg@8.10.2: + resolution: {integrity: sha512-MKFs9P6nJ+LAeHLU3V0cODEOgyThJ3OAnmOlsZsxux6sfQs3HRXR5bBn7xG5DjckEFhTAxsXi7k7cd0pCMxpJw==} + dependencies: + '@types/node': 18.0.0 + pg-protocol: 1.6.0 + pg-types: 4.0.1 + dev: true + /@types/pluralize@0.0.29: resolution: {integrity: sha512-BYOID+l2Aco2nBik+iYS4SZX0Lf20KPILP5RGmM1IgzdwNdTs0eebiFriOPcej1sX9mLnSoiNte5zcFxssgpGA==} dev: true @@ -4737,6 +4748,11 @@ packages: resolution: {integrity: sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==} dev: true + /buffer-writer@2.0.0: + resolution: {integrity: sha512-a7ZpuTZU1TRtnwyCNW3I5dc0wWNC3VR9S++Ewyk2HHZdrO3CQJqSpd+95Us590V6AL7JqUAH2IwZ/398PmNFgw==} + engines: {node: '>=4'} + dev: false + /buffer@5.7.1: resolution: {integrity: sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ==} dependencies: @@ -8829,6 +8845,10 @@ packages: object-keys: 1.1.1 dev: true + /obuf@1.1.2: + resolution: {integrity: sha512-PX1wu0AmAdPqOL1mWhqmlOd8kOIZQwGZw6rh7uby9fTc5lhaOWFLX3I6R1hrF9k3zUY40e6igsLGkDXK92LJNg==} + dev: true + /on-exit-leak-free@2.1.0: resolution: {integrity: sha512-VuCaZZAjReZ3vUwgOB8LxAosIurDiAW0s13rI1YwmaP++jvcxP77AWoQvenZebpCA2m8WC1/EosPYPMjnRAp/w==} dev: true @@ -8963,6 +8983,10 @@ packages: resolution: {integrity: sha512-R4nPAVTAU0B9D35/Gk3uJf/7XYbQcyohSKdvAxIRSNghFl4e71hVoGnBNQz9cWaXxO2I10KTC+3jMdvvoKw6dQ==} engines: {node: '>=6'} + /packet-reader@1.0.0: + resolution: {integrity: sha512-HAKu/fG3HpHFO0AA8WE8q2g+gBJaZ9MG7fcKk+IJPLTGAD6Psw4443l+9DGRbOIh3/aXr7Phy0TjilYivJo5XQ==} + dev: false + /param-case@3.0.4: resolution: {integrity: sha512-RXlj7zCYokReqWpOPH9oYivUzLYZ5vAPIfEmCTNViosC78F8F0H9y7T7gG2M39ymgutxF5gcFEsyZQSph9Bp3A==} dependencies: @@ -9075,6 +9099,86 @@ packages: is-reference: 3.0.1 dev: true + /pg-cloudflare@1.1.1: + resolution: {integrity: sha512-xWPagP/4B6BgFO+EKz3JONXv3YDgvkbVrGw2mTo3D6tVDQRh1e7cqVGvyR3BE+eQgAvx1XhW/iEASj4/jCWl3Q==} + requiresBuild: true + dev: false + optional: true + + /pg-connection-string@2.6.1: + resolution: {integrity: sha512-w6ZzNu6oMmIzEAYVw+RLK0+nqHPt8K3ZnknKi+g48Ak2pr3dtljJW3o+D/n2zzCG07Zoe9VOX3aiKpj+BN0pjg==} + dev: false + + /pg-int8@1.0.1: + resolution: {integrity: sha512-WCtabS6t3c8SkpDBUlb1kjOs7l66xsGdKpIPZsg4wR+B3+u9UAum2odSsF9tnvxg80h4ZxLWMy4pRjOsFIqQpw==} + engines: {node: '>=4.0.0'} + + /pg-numeric@1.0.2: + resolution: {integrity: sha512-BM/Thnrw5jm2kKLE5uJkXqqExRUY/toLHda65XgFTBTFYZyopbKjBe29Ii3RbkvlsMoFwD+tHeGaCjjv0gHlyw==} + engines: {node: '>=4'} + dev: true + + /pg-pool@3.6.1(pg@8.11.1): + resolution: {integrity: sha512-jizsIzhkIitxCGfPRzJn1ZdcosIt3pz9Sh3V01fm1vZnbnCMgmGl5wvGGdNN2EL9Rmb0EcFoCkixH4Pu+sP9Og==} + peerDependencies: + pg: '>=8.0' + dependencies: + pg: 8.11.1 + dev: false + + /pg-protocol@1.6.0: + resolution: {integrity: sha512-M+PDm637OY5WM307051+bsDia5Xej6d9IR4GwJse1qA1DIhiKlksvrneZOYQq42OM+spubpcNYEo2FcKQrDk+Q==} + + /pg-types@2.2.0: + resolution: {integrity: sha512-qTAAlrEsl8s4OiEQY69wDvcMIdQN6wdz5ojQiOy6YRMuynxenON0O5oCpJI6lshc6scgAY8qvJ2On/p+CXY0GA==} + engines: {node: '>=4'} + dependencies: + pg-int8: 1.0.1 + postgres-array: 2.0.0 + postgres-bytea: 1.0.0 + postgres-date: 1.0.7 + postgres-interval: 1.2.0 + dev: false + + /pg-types@4.0.1: + resolution: {integrity: sha512-hRCSDuLII9/LE3smys1hRHcu5QGcLs9ggT7I/TCs0IE+2Eesxi9+9RWAAwZ0yaGjxoWICF/YHLOEjydGujoJ+g==} + engines: {node: '>=10'} + dependencies: + pg-int8: 1.0.1 + pg-numeric: 1.0.2 + postgres-array: 3.0.2 + postgres-bytea: 3.0.0 + postgres-date: 2.0.1 + postgres-interval: 3.0.0 + postgres-range: 1.1.3 + dev: true + + /pg@8.11.1: + resolution: {integrity: sha512-utdq2obft07MxaDg0zBJI+l/M3mBRfIpEN3iSemsz0G5F2/VXx+XzqF4oxrbIZXQxt2AZzIUzyVg/YM6xOP/WQ==} + engines: {node: '>= 8.0.0'} + peerDependencies: + pg-native: '>=3.0.1' + peerDependenciesMeta: + pg-native: + optional: true + dependencies: + buffer-writer: 2.0.0 + packet-reader: 1.0.0 + pg-connection-string: 2.6.1 + pg-pool: 3.6.1(pg@8.11.1) + pg-protocol: 1.6.0 + pg-types: 2.2.0 + pgpass: 1.0.5 + optionalDependencies: + pg-cloudflare: 1.1.1 + dev: false + + /pgpass@1.0.5: + resolution: {integrity: sha512-FdW9r/jQZhSeohs1Z3sI1yxFQNFvMcnmfuj4WBMUTxOrAyLMaTcE1aAMBiTlbMNaXvBCQuVi0R7hd8udDSP7ug==} + dependencies: + split2: 4.2.0 + dev: false + /picocolors@1.0.0: resolution: {integrity: sha512-1fygroTLlHu66zi26VoTDv8yRgm0Fccecssto+MhsZ0D/DGW2sm8E8AjW7NU5VVTRt5GxbeZ5qBuJr+HyLYkjQ==} dev: true @@ -9188,6 +9292,54 @@ packages: source-map-js: 1.0.2 dev: true + /postgres-array@2.0.0: + resolution: {integrity: sha512-VpZrUqU5A69eQyW2c5CA1jtLecCsN2U/bD6VilrFDWq5+5UIEVO7nazS3TEcHf1zuPYO/sqGvUvW62g86RXZuA==} + engines: {node: '>=4'} + dev: false + + /postgres-array@3.0.2: + resolution: {integrity: sha512-6faShkdFugNQCLwucjPcY5ARoW1SlbnrZjmGl0IrrqewpvxvhSLHimCVzqeuULCbG0fQv7Dtk1yDbG3xv7Veog==} + engines: {node: '>=12'} + dev: true + + /postgres-bytea@1.0.0: + resolution: {integrity: sha512-xy3pmLuQqRBZBXDULy7KbaitYqLcmxigw14Q5sj8QBVLqEwXfeybIKVWiqAXTlcvdvb0+xkOtDbfQMOf4lST1w==} + engines: {node: '>=0.10.0'} + dev: false + + /postgres-bytea@3.0.0: + resolution: {integrity: sha512-CNd4jim9RFPkObHSjVHlVrxoVQXz7quwNFpz7RY1okNNme49+sVyiTvTRobiLV548Hx/hb1BG+iE7h9493WzFw==} + engines: {node: '>= 6'} + dependencies: + obuf: 1.1.2 + dev: true + + /postgres-date@1.0.7: + resolution: {integrity: sha512-suDmjLVQg78nMK2UZ454hAG+OAW+HQPZ6n++TNDUX+L0+uUlLywnoxJKDou51Zm+zTCjrCl0Nq6J9C5hP9vK/Q==} + engines: {node: '>=0.10.0'} + dev: false + + /postgres-date@2.0.1: + resolution: {integrity: sha512-YtMKdsDt5Ojv1wQRvUhnyDJNSr2dGIC96mQVKz7xufp07nfuFONzdaowrMHjlAzY6GDLd4f+LUHHAAM1h4MdUw==} + engines: {node: '>=12'} + dev: true + + /postgres-interval@1.2.0: + resolution: {integrity: sha512-9ZhXKM/rw350N1ovuWHbGxnGh/SNJ4cnxHiM0rxE4VN41wsg8P8zWn9hv/buK00RP4WvlOyr/RBDiptyxVbkZQ==} + engines: {node: '>=0.10.0'} + dependencies: + xtend: 4.0.2 + dev: false + + /postgres-interval@3.0.0: + resolution: {integrity: sha512-BSNDnbyZCXSxgA+1f5UU2GmwhoI0aU5yMxRGO8CdFEcY2BQF9xm/7MqKnYoM1nJDk8nONNWDk9WeSmePFhQdlw==} + engines: {node: '>=12'} + dev: true + + /postgres-range@1.1.3: + resolution: {integrity: sha512-VdlZoocy5lCP0c/t66xAfclglEapXPCIVhqqJRncYpvbCgImF0w67aPKfbqUMr72tO2k5q0TdTZwCLjPTI6C9g==} + dev: true + /prebuild-install@7.1.1: resolution: {integrity: sha512-jAXscXWMcCK8GgCoHOfIr0ODh5ai8mj63L2nWrjuAgXE6tDyYGnx4/8o/rCgU+B4JSyZBKbeZqzhtwtC3ovxjw==} engines: {node: '>=10'} @@ -10018,7 +10170,6 @@ packages: /split2@4.2.0: resolution: {integrity: sha512-UcjcJOWknrNkF6PLX83qcHM6KHgVKNkV62Y8a5uYDVv9ydGQVwAHMKqHdJje1VTWpljG0WYpCDhrCdAOYH4TWg==} engines: {node: '>= 10.x'} - dev: true /sprintf-js@1.0.3: resolution: {integrity: sha512-D9cPgkvLlV3t3IzL0D0YLvGA9Ahk4PcvVwUbN0dSGr1aP0Nrt4AEnTUbuGvquEC0mA64Gqt1fzirlRs5ibXx8g==} diff --git a/tests/integration/package.json b/tests/integration/package.json index 7c2c9d096..76ae48bb7 100644 --- a/tests/integration/package.json +++ b/tests/integration/package.json @@ -14,6 +14,7 @@ "@types/bcryptjs": "^2.4.2", "@types/fs-extra": "^11.0.1", "@types/jest": "^29.5.0", + "@types/pg": "^8.10.2", "@types/supertest": "^2.0.12", "@types/tmp": "^0.2.3", "@types/uuid": "^8.3.4", @@ -40,6 +41,7 @@ "@zenstackhq/testtools": "workspace:*", "bcryptjs": "^2.4.3", "decimal.js": "^10.4.2", + "pg": "^8.11.1", "sleep-promise": "^9.1.0", "superjson": "^1.11.0" } diff --git a/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts b/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts index 2d7326ed8..3cec4e02d 100644 --- a/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts +++ b/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts @@ -154,7 +154,6 @@ describe('With Policy:deep nested', () => { myId: '1', m2: { create: { - id: 201, value: 1, m3: { create: { @@ -308,7 +307,6 @@ describe('With Policy:deep nested', () => { data: { m2: { create: { - id: 201, value: 2, m3: { create: { id: 'm3-1', value: 11 }, @@ -401,18 +399,7 @@ describe('With Policy:deep nested', () => { await expect( db.m1.update({ where: { myId: '2' }, - data: { - m2: { - update: { - m4: { - updateMany: { - where: { value: { gt: 0 } }, - data: { value: 102 }, - }, - }, - }, - }, - }, + data: { value: 1 }, }) ).toBeRejectedByPolicy(); diff --git a/tests/integration/tests/enhancements/with-policy/empty-policy.test.ts b/tests/integration/tests/enhancements/with-policy/empty-policy.test.ts index d2e4095e2..4a1a4d0c5 100644 --- a/tests/integration/tests/enhancements/with-policy/empty-policy.test.ts +++ b/tests/integration/tests/enhancements/with-policy/empty-policy.test.ts @@ -13,16 +13,18 @@ describe('With Policy:empty policy', () => { }); it('direct operations', async () => { - const { withPolicy } = await loadSchema( + const { prisma, withPolicy } = await loadSchema( ` model Model { id String @id @default(uuid()) + value Int } ` ); const db = withPolicy(); + await prisma.model.create({ data: { id: '1', value: 0 } }); await expect(db.model.create({ data: {} })).toBeRejectedByPolicy(); expect(await db.model.findMany()).toHaveLength(0); @@ -34,22 +36,24 @@ describe('With Policy:empty policy', () => { await expect(db.model.create({ data: {} })).toBeRejectedByPolicy(); await expect(db.model.createMany({ data: [{}] })).toBeRejectedByPolicy(); - await expect(db.model.update({ where: { id: '1' }, data: {} })).toBeRejectedByPolicy(); - await expect(db.model.updateMany({ data: {} })).toBeRejectedByPolicy(); + await expect(db.model.update({ where: { id: '1' }, data: { value: 1 } })).toBeRejectedByPolicy(); + await expect(db.model.updateMany({ data: { value: 1 } })).toBeRejectedByPolicy(); await expect( db.model.upsert({ where: { id: '1' }, - create: {}, - update: {}, + create: { value: 1 }, + update: { value: 1 }, }) ).toBeRejectedByPolicy(); await expect(db.model.delete({ where: { id: '1' } })).toBeRejectedByPolicy(); await expect(db.model.deleteMany()).toBeRejectedByPolicy(); - await expect(db.model.aggregate({})).toBeRejectedByPolicy(); - await expect(db.model.groupBy({})).toBeRejectedByPolicy(); - await expect(db.model.count()).toBeRejectedByPolicy(); + await expect(db.model.aggregate({ _avg: { value: true } })).resolves.toEqual( + expect.objectContaining({ _avg: { value: null } }) + ); + await expect(db.model.groupBy({ by: ['id'], _avg: { value: true } })).resolves.toHaveLength(0); + await expect(db.model.count()).resolves.toEqual(0); }); it('to-many write', async () => { diff --git a/tests/integration/tests/enhancements/with-policy/nested-to-many.test.ts b/tests/integration/tests/enhancements/with-policy/nested-to-many.test.ts index 75a8d0278..d0511dd5f 100644 --- a/tests/integration/tests/enhancements/with-policy/nested-to-many.test.ts +++ b/tests/integration/tests/enhancements/with-policy/nested-to-many.test.ts @@ -347,7 +347,7 @@ describe('With Policy:nested to-many', () => { }, }, }) - ).toBeRejectedWithCode('P2017'); + ).toBeNotFound(); await expect( db.m1.update({ diff --git a/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts b/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts index 59718a84c..645617f8c 100644 --- a/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts +++ b/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts @@ -12,7 +12,7 @@ describe('With Policy:nested to-one', () => { process.chdir(origDir); }); - it('read fitering for optional relation', async () => { + it('read filtering for optional relation', async () => { const { prisma, withPolicy } = await loadSchema( ` model M1 { @@ -119,7 +119,8 @@ describe('With Policy:nested to-one', () => { @@allow('create', value > 0) @@allow('update', value > 1) } - ` + `, + { logPrismaQuery: true } ); const db = withPolicy(); @@ -179,7 +180,8 @@ describe('With Policy:nested to-one', () => { @@allow('create', value > 0) @@allow('update', value > 1) } - ` + `, + { logPrismaQuery: true } ); const db = withPolicy(); diff --git a/tests/integration/tests/enhancements/with-policy/post-update.test.ts b/tests/integration/tests/enhancements/with-policy/post-update.test.ts index dd55caa9c..e4e45c0be 100644 --- a/tests/integration/tests/enhancements/with-policy/post-update.test.ts +++ b/tests/integration/tests/enhancements/with-policy/post-update.test.ts @@ -93,7 +93,8 @@ describe('With Policy: post update', () => { @@allow('create,read', true) @@allow('update', future().value > 1) } - ` + `, + { logPrismaQuery: true } ); const db = withPolicy(); @@ -226,7 +227,7 @@ describe('With Policy: post update', () => { where: { id: '1' }, data: { m2: { update: { value: 0 } } }, }) - ).toBeRejectedByPolicy(); + ).toResolveTruthy(); // m2 updatable await expect( db.m1.update({ diff --git a/tests/integration/tests/enhancements/with-policy/postgres.test.ts b/tests/integration/tests/enhancements/with-policy/postgres.test.ts new file mode 100644 index 000000000..6203aa991 --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/postgres.test.ts @@ -0,0 +1,526 @@ +import { AuthUser } from '@zenstackhq/runtime'; +import { loadSchemaFromFile, type WeakDbClientContract } from '@zenstackhq/testtools'; +import path from 'path'; +import { Pool } from 'pg'; + +const DB_NAME = 'todo-pg'; + +describe('With Policy: with postgres', () => { + let origDir: string; + let getDb: (user?: AuthUser) => WeakDbClientContract; + let prisma: WeakDbClientContract; + + const pool = new Pool({ user: 'postgres', password: 'abc123' }); + + beforeAll(async () => { + origDir = path.resolve('.'); + }); + + beforeEach(async () => { + await pool.query(`DROP DATABASE IF EXISTS "${DB_NAME}";`); + await pool.query(`CREATE DATABASE "${DB_NAME}";`); + + const { prisma: _prisma, withPolicy } = await loadSchemaFromFile( + path.join(__dirname, '../../schema/todo-pg.zmodel'), + { + addPrelude: false, + } + ); + getDb = withPolicy; + prisma = _prisma; + }); + + afterEach(async () => { + process.chdir(origDir); + await prisma.$disconnect(); + await pool.query(`DROP DATABASE IF EXISTS "${DB_NAME}";`); + }); + + it('user', async () => { + const user1 = { + id: 'user1', + email: 'user1@zenstack.dev', + name: 'User 1', + }; + const user2 = { + id: 'user2', + email: 'user2@zenstack.dev', + name: 'User 2', + }; + + const anonDb = getDb(); + const user1Db = getDb({ id: user1.id }); + const user2Db = getDb({ id: user2.id }); + + // create user1 + // create should succeed but result can be read back anonymously + await expect(anonDb.user.create({ data: user1 })).toBeRejectedByPolicy(); + await expect(user1Db.user.findUnique({ where: { id: user1.id } })).toResolveTruthy(); + await expect(user2Db.user.findUnique({ where: { id: user1.id } })).toResolveNull(); + + // create user2 + await expect(anonDb.user.create({ data: user2 })).toBeRejectedByPolicy(); + + // find with user1 should only get user1 + const r = await user1Db.user.findMany(); + expect(r).toHaveLength(1); + expect(r[0]).toEqual(expect.objectContaining(user1)); + + // get user2 as user1 + await expect(user1Db.user.findUnique({ where: { id: user2.id } })).toResolveNull(); + + // add both users into the same space + await expect( + user1Db.space.create({ + data: { + name: 'Space 1', + slug: 'space1', + owner: { connect: { id: user1.id } }, + members: { + create: [ + { + user: { connect: { id: user1.id } }, + role: 'ADMIN', + }, + { + user: { connect: { id: user2.id } }, + role: 'USER', + }, + ], + }, + }, + }) + ).toResolveTruthy(); + + // now both user1 and user2 should be visible + await expect(user1Db.user.findMany()).resolves.toHaveLength(2); + await expect(user2Db.user.findMany()).resolves.toHaveLength(2); + + // update user2 as user1 + await expect( + user2Db.user.update({ + where: { id: user1.id }, + data: { name: 'hello' }, + }) + ).toBeRejectedByPolicy(); + + // update user1 as user1 + await expect( + user1Db.user.update({ + where: { id: user1.id }, + data: { name: 'hello' }, + }) + ).toResolveTruthy(); + + // delete user2 as user1 + await expect(user1Db.user.delete({ where: { id: user2.id } })).toBeRejectedByPolicy(); + + // delete user1 as user1 + await expect(user1Db.user.delete({ where: { id: user1.id } })).toResolveTruthy(); + await expect(user1Db.user.findUnique({ where: { id: user1.id } })).toResolveNull(); + }); + + it('todo list', async () => { + await createSpaceAndUsers(prisma); + + const anonDb = getDb(); + const emptyUIDDb = getDb({ id: '' }); + const user1Db = getDb({ id: user1.id }); + const user2Db = getDb({ id: user2.id }); + const user3Db = getDb({ id: user3.id }); + + await expect( + anonDb.list.create({ + data: { + id: 'list1', + title: 'List 1', + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + }, + }) + ).toBeRejectedByPolicy(); + + await expect( + user1Db.list.create({ + data: { + id: 'list1', + title: 'List 1', + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + }, + }) + ).toResolveTruthy(); + + await expect(user1Db.list.findMany()).resolves.toHaveLength(1); + await expect(anonDb.list.findMany()).resolves.toHaveLength(0); + await expect(emptyUIDDb.list.findMany()).resolves.toHaveLength(0); + await expect(anonDb.list.findUnique({ where: { id: 'list1' } })).toResolveNull(); + + // accessible to owner + await expect(user1Db.list.findUnique({ where: { id: 'list1' } })).resolves.toEqual( + expect.objectContaining({ id: 'list1', title: 'List 1' }) + ); + + // accessible to user in the space + await expect(user2Db.list.findUnique({ where: { id: 'list1' } })).toResolveTruthy(); + + // inaccessible to user not in the space + await expect(user3Db.list.findUnique({ where: { id: 'list1' } })).toResolveNull(); + + // make a private list + await user1Db.list.create({ + data: { + id: 'list2', + title: 'List 2', + private: true, + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + }, + }); + + // accessible to owner + await expect(user1Db.list.findUnique({ where: { id: 'list2' } })).toResolveTruthy(); + + // inaccessible to other user in the space + await expect(user2Db.list.findUnique({ where: { id: 'list2' } })).toResolveNull(); + + // create a list which doesn't match credential should fail + await expect( + user1Db.list.create({ + data: { + id: 'list3', + title: 'List 3', + owner: { connect: { id: user2.id } }, + space: { connect: { id: space1.id } }, + }, + }) + ).toBeRejectedByPolicy(); + + // create a list which doesn't match credential's space should fail + await expect( + user1Db.list.create({ + data: { + id: 'list3', + title: 'List 3', + owner: { connect: { id: user1.id } }, + space: { connect: { id: space2.id } }, + }, + }) + ).toBeRejectedByPolicy(); + + // update list + await expect( + user1Db.list.update({ + where: { id: 'list1' }, + data: { + title: 'List 1 updated', + }, + }) + ).resolves.toEqual(expect.objectContaining({ title: 'List 1 updated' })); + + await expect( + user2Db.list.update({ + where: { id: 'list1' }, + data: { + title: 'List 1 updated', + }, + }) + ).toBeRejectedByPolicy(); + + // delete list + await expect(user2Db.list.delete({ where: { id: 'list1' } })).toBeRejectedByPolicy(); + await expect(user1Db.list.delete({ where: { id: 'list1' } })).toResolveTruthy(); + await expect(user1Db.list.findUnique({ where: { id: 'list1' } })).toResolveNull(); + }); + + it('todo', async () => { + await createSpaceAndUsers(prisma); + + const user1Db = getDb({ id: user1.id }); + const user2Db = getDb({ id: user2.id }); + + // create a public list + await user1Db.list.create({ + data: { + id: 'list1', + title: 'List 1', + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + }, + }); + + // create + await expect( + user1Db.todo.create({ + data: { + id: 'todo1', + title: 'Todo 1', + owner: { connect: { id: user1.id } }, + list: { + connect: { id: 'list1' }, + }, + }, + }) + ).toResolveTruthy(); + + await expect( + user2Db.todo.create({ + data: { + id: 'todo2', + title: 'Todo 2', + owner: { connect: { id: user2.id } }, + list: { + connect: { id: 'list1' }, + }, + }, + }) + ).toResolveTruthy(); + + // read + await expect(user1Db.todo.findMany()).resolves.toHaveLength(2); + await expect(user2Db.todo.findMany()).resolves.toHaveLength(2); + + // update, user in the same space can freely update + await expect( + user1Db.todo.update({ + where: { id: 'todo1' }, + data: { + title: 'Todo 1 updated', + }, + }) + ).toResolveTruthy(); + await expect( + user1Db.todo.update({ + where: { id: 'todo2' }, + data: { + title: 'Todo 2 updated', + }, + }) + ).toResolveTruthy(); + + // create a private list + await user1Db.list.create({ + data: { + id: 'list2', + private: true, + title: 'List 2', + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + }, + }); + + // create + await expect( + user1Db.todo.create({ + data: { + id: 'todo3', + title: 'Todo 3', + owner: { connect: { id: user1.id } }, + list: { + connect: { id: 'list2' }, + }, + }, + }) + ).toResolveTruthy(); + + // reject because list2 is private + await expect( + user2Db.todo.create({ + data: { + id: 'todo4', + title: 'Todo 4', + owner: { connect: { id: user2.id } }, + list: { + connect: { id: 'list2' }, + }, + }, + }) + ).toBeRejectedByPolicy(); + + // update, only owner can update todo in a private list + await expect( + user1Db.todo.update({ + where: { id: 'todo3' }, + data: { + title: 'Todo 3 updated', + }, + }) + ).toResolveTruthy(); + await expect( + user2Db.todo.update({ + where: { id: 'todo3' }, + data: { + title: 'Todo 3 updated', + }, + }) + ).toBeRejectedByPolicy(); + }); + + it('relation query', async () => { + await createSpaceAndUsers(prisma); + + const user1Db = getDb({ id: user1.id }); + const user2Db = getDb({ id: user2.id }); + + await user1Db.list.create({ + data: { + id: 'list1', + title: 'List 1', + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + }, + }); + + await user1Db.list.create({ + data: { + id: 'list2', + title: 'List 2', + private: true, + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + }, + }); + + const r = await user1Db.space.findFirst({ + where: { id: 'space1' }, + include: { lists: true }, + }); + expect(r.lists).toHaveLength(2); + + const r1 = await user2Db.space.findFirst({ + where: { id: 'space1' }, + include: { lists: true }, + }); + expect(r1.lists).toHaveLength(1); + }); + + it('post-update checks', async () => { + await createSpaceAndUsers(prisma); + + const user1Db = getDb({ id: user1.id }); + + await user1Db.list.create({ + data: { + id: 'list1', + title: 'List 1', + owner: { connect: { id: user1.id } }, + space: { connect: { id: space1.id } }, + todos: { + create: { + id: 'todo1', + title: 'Todo 1', + owner: { connect: { id: user1.id } }, + }, + }, + }, + }); + + // change list's owner + await expect( + user1Db.list.update({ + where: { id: 'list1' }, + data: { + owner: { connect: { id: user2.id } }, + }, + }) + ).toBeRejectedByPolicy(); + + // change todo's owner + await expect( + user1Db.todo.update({ + where: { id: 'todo1' }, + data: { + owner: { connect: { id: user2.id } }, + }, + }) + ).toBeRejectedByPolicy(); + + // nested change todo's owner + await expect( + user1Db.list.update({ + where: { id: 'list1' }, + data: { + todos: { + update: { + where: { id: 'todo1' }, + data: { + owner: { connect: { id: user2.id } }, + }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + }); +}); + +const user1 = { + id: 'user1', + email: 'user1@zenstack.dev', + name: 'User 1', +}; + +const user2 = { + id: 'user2', + email: 'user2@zenstack.dev', + name: 'User 2', +}; + +const user3 = { + id: 'user3', + email: 'user3@zenstack.dev', + name: 'User 3', +}; + +const space1 = { + id: 'space1', + name: 'Space 1', + slug: 'space1', +}; + +const space2 = { + id: 'space2', + name: 'Space 2', + slug: 'space2', +}; + +async function createSpaceAndUsers(db: WeakDbClientContract) { + // create users + await db.user.create({ data: user1 }); + await db.user.create({ data: user2 }); + await db.user.create({ data: user3 }); + + // add user1 and user2 into space1 + await db.space.create({ + data: { + ...space1, + members: { + create: [ + { + user: { connect: { id: user1.id } }, + role: 'ADMIN', + }, + { + user: { connect: { id: user2.id } }, + role: 'USER', + }, + ], + }, + }, + }); + + // add user3 to space2 + await db.space.create({ + data: { + ...space2, + members: { + create: [ + { + user: { connect: { id: user3.id } }, + role: 'ADMIN', + }, + ], + }, + }, + }); +} diff --git a/tests/integration/tests/enhancements/with-policy/refactor.test.ts b/tests/integration/tests/enhancements/with-policy/refactor.test.ts new file mode 100644 index 000000000..a4666c18e --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/refactor.test.ts @@ -0,0 +1,1143 @@ +import { AuthUser, PrismaErrorCode } from '@zenstackhq/runtime'; +import { loadSchemaFromFile, type WeakDbClientContract } from '@zenstackhq/testtools'; +import path from 'path'; +import { Pool } from 'pg'; + +const DB_NAME = 'refactor'; + +describe('With Policy: refactor tests', () => { + let origDir: string; + let getDb: (user?: AuthUser) => WeakDbClientContract; + let prisma: WeakDbClientContract; + let anonDb: WeakDbClientContract; + let adminDb: WeakDbClientContract; + let user1Db: WeakDbClientContract; + let user2Db: WeakDbClientContract; + + const pool = new Pool({ user: 'postgres', password: 'abc123' }); + + beforeAll(async () => { + origDir = path.resolve('.'); + }); + + beforeEach(async () => { + await pool.query(`DROP DATABASE IF EXISTS "${DB_NAME}";`); + await pool.query(`CREATE DATABASE "${DB_NAME}";`); + + const { prisma: _prisma, withPolicy } = await loadSchemaFromFile( + path.join(__dirname, '../../schema/refactor-pg.zmodel'), + { + addPrelude: false, + logPrismaQuery: true, + } + ); + getDb = withPolicy; + prisma = _prisma; + anonDb = getDb(); + user1Db = getDb({ id: 1 }); + user2Db = getDb({ id: 2 }); + adminDb = getDb({ id: 100, role: 'ADMIN' }); + }); + + afterEach(async () => { + process.chdir(origDir); + await prisma.$disconnect(); + await pool.query(`DROP DATABASE IF EXISTS "${DB_NAME}";`); + }); + + it('read', async () => { + // empty table + await expect(anonDb.user.findMany()).resolves.toHaveLength(0); + await expect(anonDb.user.findUnique({ where: { id: 1 } })).toResolveNull(); + await expect(anonDb.user.findUniqueOrThrow({ where: { id: 1 } })).toBeNotFound(); + await expect(anonDb.user.findFirst({ where: { id: 1 } })).toResolveNull(); + await expect(anonDb.user.findFirstOrThrow({ where: { id: 1 } })).toBeNotFound(); + + await prisma.user.create({ + data: { + id: 1, + email: 'user1@zenstack.dev', + profile: { + create: { + name: 'User 1', + private: true, + }, + }, + posts: { + create: [ + { + title: 'Post 1', + published: true, + comments: { create: { id: 1, authorId: 1, content: 'Comment 1' } }, + }, + { + title: 'Post 2', + published: false, + comments: { create: { id: 2, authorId: 1, content: 'Comment 2' } }, + }, + ], + }, + }, + }); + + // simple read + await expect(anonDb.user.findMany()).resolves.toHaveLength(0); + await expect(adminDb.user.findMany()).resolves.toHaveLength(1); + await expect(user1Db.user.findMany()).resolves.toHaveLength(1); + await expect(user2Db.user.findMany()).resolves.toHaveLength(1); + await expect(anonDb.user.findUnique({ where: { id: 1 } })).toResolveNull(); + await expect(adminDb.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await expect(user1Db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await expect(user2Db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + + // included profile got filtered + await expect(user1Db.user.findUnique({ include: { profile: true }, where: { id: 1 } })).resolves.toMatchObject({ + email: 'user1@zenstack.dev', + profile: expect.objectContaining({ name: 'User 1' }), + }); + await expect(user2Db.user.findUnique({ include: { profile: true }, where: { id: 1 } })).resolves.toMatchObject({ + email: 'user1@zenstack.dev', + profile: null, + }); + + // filter by profile + await expect(user1Db.user.findFirst({ where: { profile: { name: 'User 1' } } })).toResolveTruthy(); + await expect(user2Db.user.findFirst({ where: { profile: { name: 'User 1' } } })).toResolveFalsy(); + + // include profile cause toplevel user got filtered + await expect(user1Db.profile.findUnique({ include: { user: true }, where: { userId: 1 } })).toResolveTruthy(); + await expect(user2Db.profile.findUnique({ include: { user: true }, where: { userId: 1 } })).toResolveNull(); + + // posts got filtered + expect((await user1Db.user.findUnique({ include: { posts: true }, where: { id: 1 } })).posts).toHaveLength(2); + expect((await user2Db.user.findUnique({ include: { posts: true }, where: { id: 1 } })).posts).toHaveLength(1); + + // filter by posts + await expect( + user1Db.user.findFirst({ + where: { posts: { some: { title: 'Post 2' } } }, + }) + ).toResolveTruthy(); + await expect( + user2Db.user.findFirst({ + where: { posts: { some: { title: 'Post 2' } } }, + }) + ).toResolveFalsy(); + + // deep filter with comment + await expect( + user1Db.user.findFirst({ where: { posts: { some: { comments: { every: { content: 'Comment 2' } } } } } }) + ).toResolveTruthy(); + await expect( + user2Db.user.findFirst({ where: { posts: { some: { comments: { every: { content: 'Comment 2' } } } } } }) + ).toResolveNull(); + }); + + it('create', async () => { + // validation check + await expect( + anonDb.user.create({ + data: { email: 'abcd' }, + }) + ).toBeRejectedByPolicy(); + + // read back check + await expect( + anonDb.user.create({ + data: { id: 1, email: 'user1@zenstack.dev' }, + }) + ).rejects.toThrow(/not allowed to be read back/); + + // success + await expect(user1Db.user.findUnique({ where: { id: 1 } })).toResolveTruthy(); + + // nested creation failure + await expect( + anonDb.user.create({ + data: { + id: 2, + email: 'user2@zenstack.dev', + posts: { + create: { + id: 2, + title: 'A very long post title', + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + // check no partial creation + await expect(adminDb.user.findUnique({ where: { id: 2 } })).toResolveFalsy(); + + // deeply nested creation failure + await expect( + anonDb.user.create({ + data: { + id: 2, + email: 'user2@zenstack.dev', + posts: { + create: { + id: 2, + title: 'Post 2', + comments: { + create: { + authorId: 1, + content: 'Comment 2', + }, + }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + // check no partial creation + await expect(adminDb.user.findUnique({ where: { id: 2 } })).toResolveFalsy(); + + // deeply nested creation success + await expect( + user2Db.user.create({ + data: { + id: 2, + email: 'user2@zenstack.dev', + posts: { + create: { + id: 2, + title: 'Post 2', + published: true, + comments: { + create: { + authorId: 2, + content: 'Comment 2', + }, + }, + }, + }, + }, + }) + ).toResolveTruthy(); + + // create with connect: posts + await expect( + anonDb.user.create({ + data: { + id: 3, + email: 'user3@zenstack.dev', + posts: { + connect: { id: 3 }, + }, + }, + }) + ).toBeNotFound(); + await adminDb.post.create({ + data: { id: 3, authorId: 1, title: 'Post 3' }, + }); + await expect( + anonDb.user.create({ + data: { + id: 3, + email: 'user3@zenstack.dev', + posts: { + connect: { id: 3 }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + anonDb.user.create({ + data: { + id: 3, + email: 'user3@zenstack.dev', + posts: { + connectOrCreate: { where: { id: 3 }, create: { title: 'Post 3' } }, + }, + }, + }) + ).toBeRejectedByPolicy(); + // success + await expect( + adminDb.user.create({ + data: { + id: 3, + email: 'user3@zenstack.dev', + posts: { + connect: { id: 3 }, + }, + }, + }) + ).toResolveTruthy(); + const r = await adminDb.user.create({ + include: { posts: true }, + data: { + id: 4, + email: 'user4@zenstack.dev', + posts: { + connectOrCreate: { where: { id: 4 }, create: { title: 'Post 4' } }, + }, + }, + }); + expect(r.posts[0].title).toEqual('Post 4'); + + // create with connect: profile + await expect( + anonDb.user.create({ + data: { + id: 5, + email: 'user5@zenstack.dev', + profile: { + connect: { id: 5 }, + }, + }, + }) + ).toBeNotFound(); + await adminDb.profile.create({ + data: { id: 5, userId: 1, name: 'User 5' }, + }); + await expect( + anonDb.user.create({ + data: { + id: 5, + email: 'user5@zenstack.dev', + profile: { + connect: { id: 5 }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + anonDb.user.create({ + data: { + id: 5, + email: 'user5@zenstack.dev', + profile: { + connectOrCreate: { where: { id: 5 }, create: { name: 'User 5' } }, + }, + }, + }) + ).toBeRejectedByPolicy(); + // success + await expect( + adminDb.user.create({ + data: { + id: 5, + email: 'user5@zenstack.dev', + profile: { + connect: { id: 5 }, + }, + }, + }) + ).toResolveTruthy(); + const r1 = await adminDb.user.create({ + include: { profile: true }, + data: { + id: 6, + email: 'user6@zenstack.dev', + profile: { + connectOrCreate: { where: { id: 6 }, create: { name: 'User 6' } }, + }, + }, + }); + expect(r1.profile.name).toEqual('User 6'); + + // createMany, policy violation + await expect( + anonDb.user.create({ + data: { + id: 7, + email: 'user7@zenstack.dev', + posts: { + createMany: { + data: [ + { id: 7, title: 'Post 7.1' }, + { id: 8, title: 'Post 7.2 very long title' }, + ], + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + // no partial success + await expect(adminDb.user.findUnique({ where: { id: 7 } })).toResolveFalsy(); + + // createMany, unique constraint violation + await expect( + adminDb.user.create({ + data: { + id: 7, + email: 'user7@zenstack.dev', + posts: { + createMany: { + data: [ + { id: 7, title: 'Post 7.1' }, + { id: 7, title: 'Post 7.2' }, + ], + }, + }, + }, + }) + ).toBeRejectedWithCode(PrismaErrorCode.UNIQUE_CONSTRAINT_FAILED); + // no partial success + await expect(adminDb.user.findUnique({ where: { id: 7 } })).toResolveFalsy(); + + // createMany, skip duplicates + await expect( + adminDb.user.create({ + data: { + id: 7, + email: 'user7@zenstack.dev', + posts: { + createMany: { + data: [ + { id: 7, title: 'Post 7.1' }, + { id: 7, title: 'Post 7.2' }, + { id: 8, title: 'Post 8' }, + ], + skipDuplicates: true, + }, + }, + }, + }) + ).toResolveTruthy(); + // success + await expect(adminDb.user.findUnique({ where: { id: 7 } })).toResolveTruthy(); + await expect(adminDb.post.findUnique({ where: { id: 7 } })).toResolveTruthy(); + await expect(adminDb.post.findUnique({ where: { id: 8 } })).toResolveTruthy(); + }); + + it('createMany', async () => { + await prisma.user.create({ + data: { id: 1, email: 'user1@zenstack.dev' }, + }); + + // success + await expect( + user1Db.post.createMany({ + data: [ + { id: 1, title: 'Post 1', authorId: 1 }, + { id: 2, title: 'Post 2', authorId: 1 }, + ], + }) + ).resolves.toMatchObject({ count: 2 }); + + // unique constraint violation + await expect( + user1Db.post.createMany({ + data: [ + { id: 2, title: 'Post 2', authorId: 1 }, + { id: 3, title: 'Post 3', authorId: 1 }, + ], + }) + ).toBeRejectedWithCode(PrismaErrorCode.UNIQUE_CONSTRAINT_FAILED); + await expect(user1Db.post.findFirst({ where: { id: 3 } })).toResolveNull(); + + const r = await prisma.post.findMany(); + console.log('Existing:', JSON.stringify(r)); + + // ignore duplicates + await expect( + user1Db.post.createMany({ + data: [ + { id: 2, title: 'Post 2', authorId: 1 }, + { id: 3, title: 'Post 3', authorId: 1 }, + ], + skipDuplicates: true, + }) + ).resolves.toMatchObject({ count: 1 }); + await expect(user1Db.post.findFirst({ where: { id: 3 } })).toResolveTruthy(); + + // fail as a transaction + await expect( + user1Db.post.createMany({ + data: [ + { id: 4, title: 'Post 4 very very long', authorId: 1 }, + { id: 5, title: 'Post 5', authorId: 1 }, + ], + }) + ).toBeRejectedByPolicy(); + await expect(user1Db.post.findFirst({ where: { id: { in: [4, 5] } } })).toResolveNull(); + }); + + it('update', async () => { + await prisma.user.create({ + data: { + id: 2, + email: 'user2@zenstack.dev', + }, + }); + await prisma.user.create({ + data: { + id: 1, + email: 'user1@zenstack.dev', + profile: { + create: { + id: 1, + name: 'User 1', + private: true, + }, + }, + posts: { + create: [ + { + id: 1, + title: 'Post 1', + published: true, + comments: { create: { authorId: 1, content: 'Comment 1' } }, + }, + { + id: 2, + title: 'Post 2', + published: false, + comments: { create: { authorId: 2, content: 'Comment 2' } }, + }, + ], + }, + }, + }); + + // top-level + await expect(anonDb.user.update({ where: { id: 3 }, data: { email: 'user2@zenstack.dev' } })).toBeNotFound(); + await expect( + anonDb.user.update({ where: { id: 1 }, data: { email: 'user2@zenstack.dev' } }) + ).toBeRejectedByPolicy(); + await expect( + user2Db.user.update({ where: { id: 1 }, data: { email: 'user2@zenstack.dev' } }) + ).toBeRejectedByPolicy(); + await expect( + adminDb.user.update({ where: { id: 1 }, data: { email: 'user1-nice@zenstack.dev' } }) + ).toResolveTruthy(); + + // update nested profile + await expect( + anonDb.user.update({ + where: { id: 1 }, + data: { profile: { update: { private: false } } }, + }) + ).toBeRejectedByPolicy(); + // variation: with where + await expect( + anonDb.user.update({ + where: { id: 1 }, + data: { profile: { update: { where: { private: true }, data: { private: false } } } }, + }) + ).toBeRejectedByPolicy(); + await expect( + user2Db.user.update({ + where: { id: 1 }, + data: { profile: { update: { private: false } } }, + }) + ).toBeRejectedByPolicy(); + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { profile: { update: { private: false } } }, + }) + ).toResolveTruthy(); + // variation: with where + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { profile: { update: { where: { private: true }, data: { private: false } } } }, + }) + ).toBeNotFound(); + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { profile: { update: { where: { private: false }, data: { private: true } } } }, + }) + ).toResolveTruthy(); + + // update nested posts + await expect( + anonDb.user.update({ + where: { id: 1 }, + data: { posts: { update: { where: { id: 1 }, data: { published: false } } } }, + }) + ).toBeRejectedByPolicy(); + await expect( + user2Db.user.update({ + where: { id: 1 }, + data: { posts: { update: { where: { id: 1 }, data: { published: false } } } }, + }) + ).toBeRejectedByPolicy(); + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { posts: { update: { where: { id: 1 }, data: { published: false } } } }, + }) + ).toResolveTruthy(); + + // update nested comment prevent update of toplevel + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { + email: 'user1-updated@zenstack.dev', + posts: { + update: { + where: { id: 2 }, + data: { + comments: { + update: { where: { content: 'Comment 2' }, data: { content: 'Comment 2 updated' } }, + }, + }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect(adminDb.user.findUnique({ where: { email: 'user1-updated@zenstack.dev' } })).toResolveNull(); + await expect(adminDb.comment.findFirst({ where: { content: 'Comment 2 updated' } })).toResolveFalsy(); + + // update with create + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { + posts: { + create: { + id: 3, + title: 'Post 3', + published: true, + comments: { + create: { author: { connect: { id: 1 } }, content: 'Comment 3' }, + }, + }, + }, + }, + }) + ).toResolveTruthy(); + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { + posts: { + create: { + id: 4, + title: 'Post 4', + published: false, + comments: { + create: { + // can't create comment for unpublished post + author: { connect: { id: 1 } }, + content: 'Comment 4', + }, + }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect(user1Db.post.findUnique({ where: { id: 4 } })).toResolveNull(); + + // update with createMany + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { + posts: { + createMany: { + data: [ + { id: 4, title: 'Post 4' }, + { id: 5, title: 'Post 5' }, + ], + }, + }, + }, + }) + ).toResolveTruthy(); + expect( + user1Db.user.update({ + include: { posts: true }, + where: { id: 1 }, + data: { + posts: { + createMany: { + data: [ + { id: 5, title: 'Post 5' }, + { id: 6, title: 'Post 6' }, + ], + }, + }, + }, + }) + ).toBeRejectedWithCode(PrismaErrorCode.UNIQUE_CONSTRAINT_FAILED); + const r = await user1Db.user.update({ + include: { posts: true }, + where: { id: 1 }, + data: { + posts: { + createMany: { + data: [ + { id: 5, title: 'Post 5' }, + { id: 6, title: 'Post 6' }, + ], + skipDuplicates: true, + }, + }, + }, + }); + expect(r.posts).toHaveLength(6); + + // update with update + // profile + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { + profile: { + update: { + name: 'User1 updated', + }, + }, + }, + }) + ).toResolveTruthy(); + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { + profile: { + update: { + homepage: 'abc', // fail field validation + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + user2Db.user.update({ + where: { id: 1 }, + data: { + profile: { + update: { + name: 'User1 updated again', + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + // post + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { + posts: { + update: { + where: { id: 1 }, + data: { title: 'Post1-1' }, + }, + }, + }, + }) + ).toResolveTruthy(); + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { + posts: { + update: { + where: { id: 1 }, + data: { title: 'Post1 very long' }, // fail field validation + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + user2Db.user.update({ + where: { id: 1 }, + data: { + posts: { + update: { where: { id: 1 }, data: { title: 'Post1-2' } }, + }, + }, + }) + ).toBeRejectedByPolicy(); + // deep post + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { + posts: { + update: { + where: { id: 1 }, + data: { comments: { update: { where: { id: 1 }, data: { content: 'Comment1-1' } } } }, + }, + }, + }, + }) + ).toResolveTruthy(); + + // update with updateMany + // blocked by: https://github.com/prisma/prisma/issues/18371 + // await expect( + // user1Db.user.update({ + // where: { id: 1 }, + // data: { posts: { updateMany: { where: { id: { in: [1, 2, 3] } }, data: { title: 'My Post' } } } }, + // }) + // ).resolves.toMatchObject({ count: 3 }); + // await expect( + // user1Db.user.update({ + // where: { id: 1 }, + // data: { + // posts: { updateMany: { where: { id: { in: [1, 2, 3] } }, data: { title: 'Very long title' } } }, + // }, + // }) + // ).toBeRejectedByPolicy(); + // await expect( + // user2Db.user.update({ + // where: { id: 1 }, + // data: { posts: { updateMany: { where: { id: { in: [1, 2, 3] } }, data: { title: 'My Post' } } } }, + // }) + // ).toBeRejectedByPolicy(); + + // update with upsert + // post + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { + posts: { + upsert: { + where: { id: 1 }, + update: { title: 'Post 1-1' }, // update + create: { id: 1, title: 'Post 1' }, + }, + }, + }, + }) + ).toResolveTruthy(); + await expect(user1Db.post.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ title: 'Post 1-1' }); + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { + posts: { + upsert: { + where: { id: 7 }, + update: { title: 'Post 7-1' }, + create: { id: 1, title: 'Post 7' }, // create + }, + }, + }, + }) + ).toResolveTruthy(); + await expect(user1Db.post.findUnique({ where: { id: 7 } })).resolves.toMatchObject({ title: 'Post 7' }); + await expect( + user2Db.user.update({ + where: { id: 1 }, + data: { + posts: { + upsert: { + where: { id: 7 }, + update: { title: 'Post 7-1' }, + create: { id: 1, title: 'Post 7' }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { + posts: { + upsert: { + where: { id: 7 }, + update: { title: 'Post 7 very long' }, + create: { id: 1, title: 'Post 7' }, + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + + // update with connect + // post + await expect( + user1Db.user.update({ + where: { id: 2 }, + data: { + posts: { + connect: { id: 1 }, + }, + }, + }) + ).toResolveTruthy(); + await expect(adminDb.post.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ authorId: 2 }); + await expect( + user2Db.user.update({ + where: { id: 2 }, + data: { + posts: { + connect: { id: 2 }, // user2 can't update post2 + }, + }, + }) + ).toBeRejectedByPolicy(); + // profile + await expect( + user1Db.user.update({ where: { id: 2 }, data: { profile: { connect: { id: 1 } } } }) + ).toResolveTruthy(); + await expect(adminDb.profile.findUnique({ where: { id: 1 } })).resolves.toMatchObject({ userId: 2 }); + await expect( + user1Db.user.update({ + where: { id: 1 }, + data: { profile: { connect: { id: 2 } } }, // user1 can't update profile1 + }) + ).toBeRejectedByPolicy(); + // reassign profile1 to user1 + await adminDb.user.update({ + where: { id: 1 }, + data: { profile: { connect: { id: 1 } } }, + }); + + // update with connectOrCreate + await expect( + user1Db.profile.update({ + where: { id: 1 }, + data: { + image: { + connectOrCreate: { + where: { id: 1 }, + create: { id: 1, url: 'abc' }, // validation error + }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + user1Db.profile.update({ + where: { id: 1 }, + data: { + image: { + connectOrCreate: { + where: { id: 1 }, + create: { id: 1, url: 'http://abc.com/pic.png' }, // create + }, + }, + }, + }) + ).toResolveTruthy(); + await expect(user1Db.image.findUnique({ where: { id: 1 } })).toResolveTruthy(); + await expect(user1Db.profile.findUnique({ include: { image: true }, where: { id: 1 } })).resolves.toMatchObject( + { id: 1 } + ); + await expect( + user1Db.profile.update({ + where: { id: 1 }, + data: { + image: { + connectOrCreate: { + where: { id: 1 }, + create: { id: 1, url: 'http://abc.com/pic1.png' }, // create + }, + }, + }, + }) + ).toResolveTruthy(); + await prisma.user.update({ + where: { id: 2 }, + data: { profile: { create: { id: 2, name: 'User 2' } } }, + }); + await prisma.image.create({ data: { id: 2, url: 'http://abc.com/pic2.png' } }); + await expect( + user1Db.profile.update({ + where: { id: 2 }, + data: { + image: { + // cause update to profile which is not allowed + connectOrCreate: { where: { id: 2 }, create: { id: 2, url: 'http://abc.com/pic2-1.png' } }, + }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + user2Db.profile.update({ + where: { id: 2 }, + data: { + image: { + connectOrCreate: { + where: { id: 2 }, // connect + create: { id: 2, url: 'http://abc.com/pic2-1.png' }, + }, + }, + }, + }) + ).toResolveTruthy(); + await expect(user2Db.profile.findUnique({ include: { image: true }, where: { id: 2 } })).resolves.toMatchObject( + { + image: { url: 'http://abc.com/pic2.png' }, + } + ); + + // update with disconnect + await expect( + user1Db.profile.update({ + where: { id: 2 }, + data: { image: { disconnect: true } }, + }) + ).toBeRejectedByPolicy(); + await expect( + user2Db.profile.update({ + where: { id: 2 }, + data: { image: { disconnect: true } }, + }) + ).toResolveTruthy(); + await expect(user2Db.profile.findUnique({ include: { image: true }, where: { id: 2 } })).resolves.toMatchObject( + { image: null } + ); + + // update with set + await prisma.image.create({ data: { id: 3, url: 'http://abc.com/pic3.png' } }); + await prisma.image.create({ data: { id: 4, url: 'http://abc.com/pic4.png' } }); + await prisma.image.create({ data: { id: 5, url: 'http://abc.com/pic5.png' } }); + await prisma.image.create({ data: { id: 6, url: 'http://abc.com/pic6.png' } }); + + await expect( + user1Db.comment.update({ + where: { id: 1 }, + data: { + images: { set: [{ id: 3 }, { id: 4 }] }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + adminDb.comment.update({ + where: { id: 1 }, + data: { + images: { set: [{ id: 3 }, { id: 4 }] }, + }, + }) + ).toResolveTruthy(); + await expect(adminDb.image.findUnique({ where: { id: 3 } })).resolves.toMatchObject({ commentId: 1 }); + await expect(adminDb.image.findUnique({ where: { id: 4 } })).resolves.toMatchObject({ commentId: 1 }); + await expect( + adminDb.comment.update({ + where: { id: 1 }, + data: { + images: { set: [{ id: 5 }, { id: 6 }] }, + }, + }) + ).toResolveTruthy(); + await expect(adminDb.image.findUnique({ where: { id: 3 } })).resolves.toMatchObject({ commentId: null }); + await expect(adminDb.image.findUnique({ where: { id: 4 } })).resolves.toMatchObject({ commentId: null }); + await expect(adminDb.image.findUnique({ where: { id: 5 } })).resolves.toMatchObject({ commentId: 1 }); + await expect(adminDb.image.findUnique({ where: { id: 6 } })).resolves.toMatchObject({ commentId: 1 }); + + // update with delete + await expect( + user1Db.comment.update({ + where: { id: 1 }, + data: { + images: { delete: [{ id: 5 }, { id: 6 }] }, + }, + }) + ).toBeRejectedByPolicy(); + await expect( + adminDb.comment.update({ + where: { id: 1 }, + data: { + images: { delete: [{ id: 5 }, { id: 6 }] }, + }, + }) + ).toResolveTruthy(); + await expect(adminDb.image.findUnique({ where: { id: 5 } })).toResolveNull(); + await expect(adminDb.image.findUnique({ where: { id: 6 } })).toResolveNull(); + + // update with deleteMany + await prisma.comment.update({ + where: { id: 1 }, + data: { + images: { set: [{ id: 3 }, { id: 4 }] }, + }, + }); + await expect( + user1Db.comment.update({ + where: { id: 1 }, + data: { images: { deleteMany: { url: { contains: 'pic3' } } } }, + }) + ).toBeRejectedByPolicy(); + await expect( + adminDb.comment.update({ + where: { id: 1 }, + data: { images: { deleteMany: { url: { contains: 'pic3' } } } }, + }) + ).toResolveTruthy(); + await expect(adminDb.image.findUnique({ where: { id: 3 } })).toResolveNull(); + }); + + it('updateMany', async () => { + await prisma.user.create({ + data: { + id: 1, + email: 'user1@zenstack.dev', + profile: { + create: { id: 1, name: 'User 1', private: true }, + }, + posts: { + create: [ + { id: 1, title: 'Post 1' }, + { id: 2, title: 'Post 2' }, + ], + }, + }, + }); + await expect( + user2Db.post.updateMany({ + data: { title: 'My post' }, + }) + ).resolves.toMatchObject({ count: 0 }); + await expect( + user1Db.post.updateMany({ + data: { title: 'My long long post' }, + }) + ).toBeRejectedByPolicy(); + await expect( + user1Db.post.updateMany({ + data: { title: 'My post' }, + }) + ).resolves.toMatchObject({ count: 2 }); + }); + + it('delete', async () => { + await prisma.user.create({ + data: { + id: 1, + email: 'user1@zenstack.dev', + profile: { + create: { id: 1, name: 'User 1', private: true }, + }, + posts: { + create: [ + { id: 1, title: 'Post 1', published: true }, + { id: 2, title: 'Post 2', published: false }, + ], + }, + }, + }); + + await expect(user2Db.post.delete({ where: { id: 1 } })).toBeRejectedByPolicy(); + await expect(user1Db.post.delete({ where: { id: 1 } })).toResolveTruthy(); + }); + + it('deleteMany', async () => { + await prisma.user.create({ + data: { + id: 1, + email: 'user1@zenstack.dev', + profile: { + create: { id: 1, name: 'User 1', private: true }, + }, + posts: { + create: [ + { id: 1, title: 'Post 1', published: true }, + { id: 2, title: 'Post 2', published: false }, + ], + }, + }, + }); + + await expect(user2Db.post.deleteMany({ where: { published: true } })).resolves.toMatchObject({ count: 0 }); + await expect(user1Db.post.deleteMany({ where: { published: true } })).resolves.toMatchObject({ count: 1 }); + }); +}); diff --git a/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts b/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts index 15626e1c2..9bc78d302 100644 --- a/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts +++ b/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts @@ -72,7 +72,8 @@ describe('With Policy:toplevel operations', () => { @@allow('create', value > 0) @@allow('update', value > 1) } - ` + `, + { logPrismaQuery: true } ); const db = withPolicy(); diff --git a/tests/integration/tests/schema/refactor-pg.zmodel b/tests/integration/tests/schema/refactor-pg.zmodel new file mode 100644 index 000000000..13c30ee65 --- /dev/null +++ b/tests/integration/tests/schema/refactor-pg.zmodel @@ -0,0 +1,100 @@ +datasource db { + provider = 'postgresql' + url = 'postgres://postgres:abc123@localhost:5432/refactor' +} + +generator js { + provider = 'prisma-client-js' +} + +enum Role { + USER + ADMIN +} + +model User { + id Int @id @default(autoincrement()) + email String @unique @email + role Role @default(USER) + profile Profile? + posts Post[] + comments Comment[] + + // everybody can signup + @@allow('create', true) + + @@allow('read', auth() != null) + + // full-access by self + @@allow('all', auth() == this || auth().role == ADMIN) +} + +model Profile { + id Int @id @default(autoincrement()) + name String + homepage String? @url + private Boolean @default(false) + image Image? @relation(fields: [imageId], references: [id], onDelete: Cascade) + imageId Int? @unique + + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId Int @unique + + // user profile is publicly readable + @@allow('read', auth() != null && !private) + + // user profile is only updatable by the user + @@allow('all', auth() == user || auth().role == ADMIN) +} + +model Image { + id Int @id @default(autoincrement()) + url String @url + profile Profile? + + comment Comment? @relation(fields: [commentId], references: [id]) + commentId Int? + + @@allow('create,read', true) + @@allow('update,delete', auth().role == ADMIN) +} + +model Post { + id Int @id @default(autoincrement()) + title String @length(1, 8) + published Boolean @default(false) + comments Comment[] + author User @relation(fields: [authorId], references: [id], onDelete: Cascade) + authorId Int + + // posts are readable by all + @@allow('read', published) + + // posts are updatable by the author + @@allow('all', auth() == author || auth().role == ADMIN) +} + +model Comment { + id Int @id @default(autoincrement()) + content String + + author User @relation(fields: [authorId], references: [id], onDelete: Cascade) + authorId Int + + post Post @relation(fields: [postId], references: [id], onDelete: Cascade) + postId Int + + images Image[] + + // comments are readable by all + @@allow('read', post.published) + + @@allow('create', auth() != null && post.published && auth() == author) + + @@allow('update', auth() == author && future().author == auth()) + + @@allow('delete', auth() == author || auth() == post.author) + + // comments are updatable by the author + @@allow('all', auth().role == ADMIN) +} diff --git a/tests/integration/tests/schema/todo-pg.zmodel b/tests/integration/tests/schema/todo-pg.zmodel new file mode 100644 index 000000000..9553b59ab --- /dev/null +++ b/tests/integration/tests/schema/todo-pg.zmodel @@ -0,0 +1,152 @@ +/* +* Sample model for a collaborative Todo app +*/ + +datasource db { + provider = 'postgresql' + url = 'postgres://postgres:abc123@localhost:5432/todo-pg' +} + +generator js { + provider = 'prisma-client-js' +} + +enum UserRole { + ADMIN + USER +} + +/* + * Model for a space in which users can collaborate on Lists and Todos + */ +model Space { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + name String @length(4, 50) + slug String @unique @length(4, 16) + owner User? @relation(fields: [ownerId], references: [id]) + ownerId String? + members SpaceUser[] + lists List[] + + // require login + @@deny('all', auth() == null) + + // everyone can create a space + @@allow('create', true) + + // any user in the space can read the space + @@allow('read', members?[user == auth()]) + + // space admin can update and delete + @@allow('update,delete', members?[user == auth() && role == ADMIN]) +} + +/* + * Model representing membership of a user in a space + */ +model SpaceUser { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + space Space @relation(fields: [spaceId], references: [id], onDelete: Cascade) + spaceId String + user User @relation(fields: [userId], references: [id], onDelete: Cascade) + userId String + role UserRole + @@unique([userId, spaceId]) + + // require login + @@deny('all', auth() == null) + + // space admin can create/update/delete + @@allow('create,update,delete', space.members?[user == auth() && role == ADMIN]) + + // user can read entries for spaces which he's a member of + @@allow('read', space.members?[user == auth()]) +} + +/* + * Model for a user + */ +model User { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + email String @unique @email + password String? @password @omit + emailVerified DateTime? + name String? + ownedSpaces Space[] + spaces SpaceUser[] + image String? @url + lists List[] + todos Todo[] + + // can be created by anyone, even not logged in + @@allow('create', true) + + // can be read by users sharing any space + @@allow('read', spaces?[space.members?[user == auth()]]) + + // full access by oneself + @@allow('all', auth() == this) +} + +/* + * Model for a Todo list + */ +model List { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + space Space @relation(fields: [spaceId], references: [id], onDelete: Cascade) + spaceId String + owner User @relation(fields: [ownerId], references: [id], onDelete: Cascade) + ownerId String + title String @length(1, 100) + private Boolean @default(false) + todos Todo[] + + // require login + @@deny('all', auth() == null) + + // can be read by owner or space members (only if not private) + @@allow('read', owner == auth() || (space.members?[user == auth()] && !private)) + + // when create, owner must be set to current user, and user must be in the space + @@allow('create', owner == auth() && space.members?[user == auth()]) + + // when create, owner must be set to current user, and user must be in the space + // update is not allowed to change owner + @@allow('update', owner == auth()&& space.members?[user == auth()] && future().owner == owner) + + // can be deleted by owner + @@allow('delete', owner == auth()) +} + +/* + * Model for a single Todo + */ +model Todo { + id String @id @default(uuid()) + createdAt DateTime @default(now()) + updatedAt DateTime @updatedAt + owner User @relation(fields: [ownerId], references: [id], onDelete: Cascade) + ownerId String + list List @relation(fields: [listId], references: [id], onDelete: Cascade) + listId String + title String @length(1, 100) + completedAt DateTime? + + // require login + @@deny('all', auth() == null) + + // owner has full access, also space members have full access (if the parent List is not private) + @@allow('all', list.owner == auth()) + @@allow('all', list.space.members?[user == auth()] && !list.private) + + // update is not allowed to change owner + @@deny('update', future().owner != owner) +} diff --git a/tests/integration/utils/jest-ext.ts b/tests/integration/utils/jest-ext.ts index 649ba104c..ee24741a5 100644 --- a/tests/integration/utils/jest-ext.ts +++ b/tests/integration/utils/jest-ext.ts @@ -1,8 +1,12 @@ import { format } from 'util'; import { isPrismaClientKnownRequestError } from '@zenstackhq/runtime'; +function isPromise(value: any) { + return typeof value.then === 'function' && typeof value.catch === 'function'; +} + export const toBeRejectedByPolicy = async function (received: Promise, expectedMessages?: string[]) { - if (!(received instanceof Promise)) { + if (!isPromise(received)) { return { message: () => 'a promise is expected', pass: false }; } try { @@ -28,7 +32,7 @@ export const toBeRejectedByPolicy = async function (received: Promise, }; export const toBeNotFound = async function (received: Promise) { - if (!(received instanceof Promise)) { + if (!isPromise(received)) { return { message: () => 'a promise is expected', pass: false }; } try { @@ -43,7 +47,7 @@ export const toBeNotFound = async function (received: Promise) { }; export const toBeRejectedWithCode = async function (received: Promise, code: string) { - if (!(received instanceof Promise)) { + if (!isPromise(received)) { return { message: () => 'a promise is expected', pass: false }; } try { @@ -58,7 +62,7 @@ export const toBeRejectedWithCode = async function (received: Promise, }; export const toResolveTruthy = async function (received: Promise) { - if (!(received instanceof Promise)) { + if (!isPromise(received)) { return { message: () => 'a promise is expected', pass: false }; } try { @@ -83,7 +87,7 @@ export const toResolveTruthy = async function (received: Promise) { }; export const toResolveFalsy = async function (received: Promise) { - if (!(received instanceof Promise)) { + if (!isPromise(received)) { return { message: () => 'a promise is expected', pass: false }; } try { @@ -108,7 +112,7 @@ export const toResolveFalsy = async function (received: Promise) { }; export const toResolveNull = async function (received: Promise) { - if (!(received instanceof Promise)) { + if (!isPromise(received)) { return { message: () => 'a promise is expected', pass: false }; } try { @@ -135,14 +139,14 @@ export const toResolveNull = async function (received: Promise) { function expectPrismaCode(err: any, code: string) { if (!isPrismaClientKnownRequestError(err)) { return { - message: () => `expected PrismaClientKnownRequestError', got ${err}`, + message: () => `expected PrismaClientKnownRequestError, got ${err}`, pass: false, }; } const errCode = err.code; if (errCode !== code) { return { - message: () => `expected PrismaClientKnownRequestError.code 'P2004', got ${errCode ?? err}`, + message: () => `expected PrismaClientKnownRequestError.code '${code}', got '${errCode ?? err}'`, pass: false, }; } From ac688a47fb5a5c6f241f7e8ef88d20233d314849 Mon Sep 17 00:00:00 2001 From: Yiming Date: Sat, 5 Aug 2023 12:25:55 +0800 Subject: [PATCH 4/5] refactor: remove the usage of zenstack_guard field (#615) --- .../src/enhancements/policy/handler.ts | 51 ++--- .../src/enhancements/policy/policy-utils.ts | 174 +++++++++++------- .../access-policy/expression-writer.ts | 113 ++++++------ .../access-policy/policy-guard-generator.ts | 3 +- .../tests/generator/expression-writer.test.ts | 63 ++++--- .../with-policy/deep-nested.test.ts | 2 +- .../with-policy/nested-to-one.test.ts | 3 +- .../with-policy/post-update.test.ts | 3 +- .../enhancements/with-policy/refactor.test.ts | 1 - .../with-policy/toplevel-operations.test.ts | 3 +- .../integration/tests/plugins/policy.test.ts | 4 +- 11 files changed, 221 insertions(+), 199 deletions(-) diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index c9cdb98a7..082578ba2 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -534,16 +534,16 @@ export class PolicyProxyHandler implements Pr const postWriteChecks: PostWriteCheckRecord[] = []; // registers a post-update check task - const _registerPostUpdateCheck = async (model: string, where: any, db: Record) => { + const _registerPostUpdateCheck = async (model: string, uniqueFilter: any) => { // both "post-update" rules and Zod schemas require a post-update check if (this.utils.hasAuthGuard(model, 'postUpdate') || this.utils.getZodSchema(model)) { // select pre-update field values let preValue: any; const preValueSelect = this.utils.getPreValueSelect(model); if (preValueSelect && Object.keys(preValueSelect).length > 0) { - preValue = await db[model].findFirst({ where, select: preValueSelect }); + preValue = await db[model].findFirst({ where: uniqueFilter, select: preValueSelect }); } - postWriteChecks.push({ model, operation: 'postUpdate', uniqueFilter: where, preValue }); + postWriteChecks.push({ model, operation: 'postUpdate', uniqueFilter, preValue }); } }; @@ -552,12 +552,7 @@ export class PolicyProxyHandler implements Pr // Instead, handle nested create inside update as an atomic operation that creates an entire // subtree (containing nested creates/connects) - const _create = async ( - model: string, - args: any, - context: NestedWriteVisitorContext, - db: Record - ) => { + const _create = async (model: string, args: any, context: NestedWriteVisitorContext) => { let createData = args; if (context.field?.backLink) { // handles the connection to upstream entity @@ -584,12 +579,7 @@ export class PolicyProxyHandler implements Pr postWriteChecks.push(...checks); }; - const _createMany = async ( - model: string, - args: any, - context: NestedWriteVisitorContext, - db: Record - ) => { + const _createMany = async (model: string, args: any, context: NestedWriteVisitorContext) => { if (context.field?.backLink) { // handles the connection to upstream entity const reversedQuery = await this.utils.buildReversedQuery(context); @@ -602,12 +592,7 @@ export class PolicyProxyHandler implements Pr postWriteChecks.push(...checks); }; - const _connectDisconnect = async ( - model: string, - args: any, - context: NestedWriteVisitorContext, - db: Record - ) => { + const _connectDisconnect = async (model: string, args: any, context: NestedWriteVisitorContext) => { if (context.field?.backLink) { const backLinkField = this.utils.getModelField(model, context.field.backLink); if (backLinkField.isRelationOwner) { @@ -615,7 +600,7 @@ export class PolicyProxyHandler implements Pr await this.utils.checkPolicyForUnique(model, args, 'update', db); // register post-update check - await _registerPostUpdateCheck(model, args, db); + await _registerPostUpdateCheck(model, args); } } }; @@ -669,7 +654,7 @@ export class PolicyProxyHandler implements Pr } // register post-update check - await _registerPostUpdateCheck(model, ids, db); + await _registerPostUpdateCheck(model, ids); } }, @@ -706,7 +691,7 @@ export class PolicyProxyHandler implements Pr create: async (model, args, context) => { // process the entire create subtree separately - await _create(model, args, context, db); + await _create(model, args, context); // remove it from the update payload delete context.parent.create; @@ -717,7 +702,7 @@ export class PolicyProxyHandler implements Pr createMany: async (model, args, context) => { // process createMany separately - await _createMany(model, args, context, db); + await _createMany(model, args, context); // remove it from the update payload delete context.parent.createMany; @@ -739,7 +724,7 @@ export class PolicyProxyHandler implements Pr await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db); // register post-update check - await _registerPostUpdateCheck(model, uniqueFilter, uniqueFilter); + await _registerPostUpdateCheck(model, uniqueFilter); // convert upsert to update context.parent.update = { where: args.where, data: args.update }; @@ -751,7 +736,7 @@ export class PolicyProxyHandler implements Pr // create case // process the entire create subtree separately - await _create(model, args.create, context, db); + await _create(model, args.create, context); // remove it from the update payload delete context.parent.upsert; @@ -761,21 +746,21 @@ export class PolicyProxyHandler implements Pr } }, - connect: async (model, args, context) => _connectDisconnect(model, args, context, db), + connect: async (model, args, context) => _connectDisconnect(model, args, context), connectOrCreate: async (model, args, context) => { // the where condition is already unique, so we can use it to check if the target exists const existing = await this.utils.checkExistence(db, model, args.where); if (existing) { // connect - await _connectDisconnect(model, args.where, context, db); + await _connectDisconnect(model, args.where, context); } else { // create - await _create(model, args.create, context, db); + await _create(model, args.create, context); } }, - disconnect: async (model, args, context) => _connectDisconnect(model, args, context, db), + disconnect: async (model, args, context) => _connectDisconnect(model, args, context), set: async (model, args, context) => { // find the set of items to be replaced @@ -790,10 +775,10 @@ export class PolicyProxyHandler implements Pr const currentSet = await db[model].findMany(findCurrSetArgs); // register current set for update (foreign key) - await Promise.all(currentSet.map((item) => _connectDisconnect(model, item, context, db))); + await Promise.all(currentSet.map((item) => _connectDisconnect(model, item, context))); // proceed with connecting the new set - await Promise.all(enumerate(args).map((item) => _connectDisconnect(model, item, context, db))); + await Promise.all(enumerate(args).map((item) => _connectDisconnect(model, item, context))); }, delete: async (model, args, context) => { diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index e83c77454..7508760fa 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -4,7 +4,7 @@ import deepcopy from 'deepcopy'; import { lowerCaseFirst } from 'lower-case-first'; import { upperCaseFirst } from 'upper-case-first'; import { fromZodError } from 'zod-validation-error'; -import { AUXILIARY_FIELDS, CrudFailureReason, GUARD_FIELD_NAME, PrismaErrorCode } from '../../constants'; +import { AUXILIARY_FIELDS, CrudFailureReason, PrismaErrorCode } from '../../constants'; import { AuthUser, DbClientContract, DbOperations, FieldInfo, PolicyOperationKind } from '../../types'; import { getVersion } from '../../version'; import { getFields, resolveField } from '../model-meta'; @@ -45,56 +45,111 @@ export class PolicyUtil { /** * Creates a conjunction of a list of query conditions. */ - and(...conditions: (boolean | object)[]): any { - // TODO: reduction - - if (conditions.includes(false)) { - // always false - return { [GUARD_FIELD_NAME]: false }; - } - - const filtered = conditions.filter( - (c): c is object => typeof c === 'object' && !!c && Object.keys(c).length > 0 - ); - if (filtered.length === 0) { - return undefined; - } else if (filtered.length === 1) { - return filtered[0]; - } else { - return { AND: filtered }; - } + and(...conditions: (boolean | object)[]): object { + return this.reduce({ AND: conditions }); } /** * Creates a disjunction of a list of query conditions. */ - or(...conditions: (boolean | object)[]): any { - // TODO: reduction + or(...conditions: (boolean | object)[]): object { + return this.reduce({ OR: conditions }); + } - if (conditions.includes(true)) { - // always true - return { [GUARD_FIELD_NAME]: true }; + /** + * Creates a negation of a query condition. + */ + not(condition: object | boolean | undefined): object { + if (condition === undefined) { + return this.makeTrue(); + } else if (typeof condition === 'boolean') { + return this.reduce(!condition); + } else { + return this.reduce({ NOT: condition }); } + } - const filtered = conditions.filter((c): c is object => typeof c === 'object' && !!c); - if (filtered.length === 0) { - return undefined; - } else if (filtered.length === 1) { - return filtered[0]; + // Static True/False conditions + // https://www.prisma.io/docs/concepts/components/prisma-client/null-and-undefined#the-effect-of-null-and-undefined-on-conditionals + + private isTrue(condition: object) { + if (condition === null || condition === undefined) { + return false; } else { - return { OR: filtered }; + return ( + (typeof condition === 'object' && Object.keys(condition).length === 0) || + ('AND' in condition && Array.isArray(condition.AND) && condition.AND.length === 0) + ); } } - /** - * Creates a negation of a query condition. - */ - not(condition: object | boolean): any { - if (typeof condition === 'boolean') { - return !condition; + private isFalse(condition: object) { + if (condition === null || condition === undefined) { + return false; } else { - return { NOT: condition }; + return 'OR' in condition && Array.isArray(condition.OR) && condition.OR.length === 0; + } + } + + private makeTrue() { + return { AND: [] }; + } + + private makeFalse() { + return { OR: [] }; + } + + private reduce(condition: object | boolean | undefined): object { + if (condition === true || condition === undefined) { + return this.makeTrue(); + } + + if (condition === false) { + return this.makeFalse(); + } + + if ('AND' in condition && Array.isArray(condition.AND)) { + const children = condition.AND.map((c: any) => this.reduce(c)).filter( + (c) => c !== undefined && !this.isTrue(c) + ); + if (children.length === 0) { + return this.makeTrue(); + } else if (children.some((c) => this.isFalse(c))) { + return this.makeFalse(); + } else if (children.length === 1) { + return children[0]; + } else { + return { AND: children }; + } + } + + if ('OR' in condition && Array.isArray(condition.OR)) { + const children = condition.OR.map((c: any) => this.reduce(c)).filter( + (c) => c !== undefined && !this.isFalse(c) + ); + if (children.length === 0) { + return this.makeFalse(); + } else if (children.some((c) => this.isTrue(c))) { + return this.makeTrue(); + } else if (children.length === 1) { + return children[0]; + } else { + return { OR: children }; + } + } + + if ('NOT' in condition && condition.NOT !== null && typeof condition.NOT === 'object') { + const child = this.reduce(condition.NOT); + if (this.isTrue(child)) { + return this.makeFalse(); + } else if (this.isFalse(child)) { + return this.makeTrue(); + } else { + return { NOT: child }; + } } + + return condition; } //#endregion @@ -107,7 +162,7 @@ export class PolicyUtil { * @returns true if operation is unconditionally allowed, false if unconditionally denied, * otherwise returns a guard object */ - getAuthGuard(model: string, operation: PolicyOperationKind, preValue?: any): boolean | object { + getAuthGuard(model: string, operation: PolicyOperationKind, preValue?: any): object { const guard = this.policy.guard[lowerCaseFirst(model)]; if (!guard) { throw this.unknownError(`unable to load policy guard for ${model}`); @@ -115,13 +170,14 @@ export class PolicyUtil { const provider: PolicyFunc | boolean | undefined = guard[operation]; if (typeof provider === 'boolean') { - return provider; + return this.reduce(provider); } if (!provider) { throw this.unknownError(`zenstack: unable to load authorization guard for ${model}`); } - return provider({ user: this.user, preValue }); + const r = provider({ user: this.user, preValue }); + return this.reduce(r); } /** @@ -165,10 +221,8 @@ export class PolicyUtil { */ async injectAuthGuard(args: any, model: string, operation: PolicyOperationKind) { const guard = this.getAuthGuard(model, operation); - if (guard === false) { - // use OR with 0 filters to represent filtering out everything - // https://www.prisma.io/docs/concepts/components/prisma-client/null-and-undefined#the-effect-of-null-and-undefined-on-conditionals - args.where = { OR: [] }; + if (this.isFalse(guard)) { + args.where = this.makeFalse(); return false; } @@ -179,15 +233,7 @@ export class PolicyUtil { await this.injectGuardForRelationFields(model, args.where, operation); } - const combined = this.and(args.where, guard); - if (combined !== undefined) { - args.where = combined; - } else { - // use AND with 0 filters to represent no filtering - // https://www.prisma.io/docs/concepts/components/prisma-client/null-and-undefined#the-effect-of-null-and-undefined-on-conditionals - args.where = { AND: [] }; - } - + args.where = this.and(args.where, guard); return true; } @@ -286,7 +332,7 @@ export class PolicyUtil { await this.injectGuardForRelationFields(model, args.where, 'read'); } - if (injected.where && Object.keys(injected.where).length > 0) { + if (injected.where && Object.keys(injected.where).length > 0 && !this.isTrue(injected.where)) { args.where = args.where ?? {}; Object.assign(args.where, injected.where); } @@ -445,11 +491,7 @@ export class PolicyUtil { await this.injectAuthGuard(injectTarget[field], fieldInfo.type, 'read'); } else { // hoist non-nullable to-one filter to the parent level - const guard = this.getAuthGuard(fieldInfo.type, 'read'); - if (guard !== true) { - // use "and" to resolve boolean values - hoisted = this.and(guard); - } + hoisted = this.getAuthGuard(fieldInfo.type, 'read'); } // recurse @@ -459,7 +501,7 @@ export class PolicyUtil { hoisted = this.and(hoisted, ...subHoisted); } - if (hoisted !== undefined) { + if (hoisted && !this.isTrue(hoisted)) { hoistedConditions.push({ [field]: hoisted }); } } @@ -479,14 +521,14 @@ export class PolicyUtil { preValue?: any ) { const guard = this.getAuthGuard(model, operation, preValue); - if (guard === false) { + if (this.isFalse(guard)) { throw this.deniedByPolicy(model, operation, `entity ${formatObject(uniqueFilter)} failed policy check`); } // Zod schema is to be checked for "create" and "postUpdate" const schema = ['create', 'postUpdate'].includes(operation) ? this.getZodSchema(model) : undefined; - if (guard === true && !schema) { + if (this.isTrue(guard) && !schema) { // unconditionally allowed return; } @@ -502,9 +544,7 @@ export class PolicyUtil { this.flattenGeneratedUniqueField(model, where); // query with policy guard - if (guard !== true) { - where = this.and(where, guard); - } + where = this.and(where, guard); const query = { select, where }; if (this.shouldLogQuery) { @@ -538,7 +578,7 @@ export class PolicyUtil { */ tryReject(model: string, operation: PolicyOperationKind) { const guard = this.getAuthGuard(model, operation); - if (guard === false) { + if (this.isFalse(guard)) { throw this.deniedByPolicy(model, operation); } } @@ -594,7 +634,7 @@ export class PolicyUtil { } if (this.shouldLogQuery) { - this.logger.info(`[policy] \`findFirst\` ${model}:\n${formatObject(readArgs)}`); + this.logger.info(`[policy] checking read-back, \`findFirst\` ${model}:\n${formatObject(readArgs)}`); } const result = await db[model].findFirst(readArgs); if (!result) { diff --git a/packages/schema/src/plugins/access-policy/expression-writer.ts b/packages/schema/src/plugins/access-policy/expression-writer.ts index be0cf25a5..e45bce0cb 100644 --- a/packages/schema/src/plugins/access-policy/expression-writer.ts +++ b/packages/schema/src/plugins/access-policy/expression-writer.ts @@ -14,13 +14,7 @@ import { ReferenceExpr, UnaryExpr, } from '@zenstackhq/language/ast'; -import { - ExpressionContext, - getFunctionExpressionContext, - getLiteral, - GUARD_FIELD_NAME, - PluginError, -} from '@zenstackhq/sdk'; +import { ExpressionContext, getFunctionExpressionContext, getLiteral, PluginError } from '@zenstackhq/sdk'; import { CodeBlockWriter } from 'ts-morph'; import { name } from '.'; import { getIdFields, isAuthInvocation } from '../../utils/ast-utils'; @@ -47,6 +41,11 @@ type FilterOperators = | 'hasSome' | 'isEmpty'; +// { OR: [] } filters to nothing, { AND: [] } includes everything +// https://www.prisma.io/docs/concepts/components/prisma-client/null-and-undefined#the-effect-of-null-and-undefined-on-conditionals +const TRUE = '{ AND: [] }'; +const FALSE = '{ OR: [] }'; + /** * Utility for writing ZModel expression as Prisma query argument objects into a ts-morph writer */ @@ -110,19 +109,19 @@ export class ExpressionWriter { } private writeMemberAccess(expr: MemberAccessExpr) { - this.block(() => { - if (this.isAuthOrAuthMemberAccess(expr)) { - // member access of `auth()`, generate plain expression - this.guard(() => this.plain(expr), true); - } else { + if (this.isAuthOrAuthMemberAccess(expr)) { + // member access of `auth()`, generate plain expression + this.guard(() => this.plain(expr), true); + } else { + this.block(() => { // must be a boolean member this.writeFieldCondition(expr.operand, () => { this.block(() => { this.writer.write(`${expr.member.ref?.name}: true`); }); }); - } - }); + }); + } } private writeExprList(exprs: Expression[]) { @@ -168,33 +167,35 @@ export class ExpressionWriter { const leftIsFieldAccess = this.isFieldAccess(expr.left); const rightIsFieldAccess = this.isFieldAccess(expr.right); - this.block(() => { - if (!leftIsFieldAccess && !rightIsFieldAccess) { - // 'in' without referencing fields - this.guard(() => this.plain(expr)); - } else if (leftIsFieldAccess && !rightIsFieldAccess) { - // 'in' with left referencing a field, right is an array literal - this.writeFieldCondition( - expr.left, - () => { - this.plain(expr.right); - }, - 'in' - ); - } else if (!leftIsFieldAccess && rightIsFieldAccess) { - // 'in' with right referencing an array field, left is a literal - // transform it into a 'has' filter - this.writeFieldCondition( - expr.right, - () => { - this.plain(expr.left); - }, - 'has' - ); - } else { - throw new PluginError(name, '"in" operator cannot be used with field references on both sides'); - } - }); + if (!leftIsFieldAccess && !rightIsFieldAccess) { + // 'in' without referencing fields + this.guard(() => this.plain(expr)); + } else { + this.block(() => { + if (leftIsFieldAccess && !rightIsFieldAccess) { + // 'in' with left referencing a field, right is an array literal + this.writeFieldCondition( + expr.left, + () => { + this.plain(expr.right); + }, + 'in' + ); + } else if (!leftIsFieldAccess && rightIsFieldAccess) { + // 'in' with right referencing an array field, left is a literal + // transform it into a 'has' filter + this.writeFieldCondition( + expr.right, + () => { + this.plain(expr.left); + }, + 'has' + ); + } else { + throw new PluginError(name, '"in" operator cannot be used with field references on both sides'); + } + }); + } } private writeCollectionPredicate(expr: BinaryExpr, operator: string) { @@ -228,14 +229,14 @@ export class ExpressionWriter { return false; } - private guard(write: () => void, cast = false) { - this.writer.write(`${GUARD_FIELD_NAME}: `); + private guard(condition: () => void, cast = false) { if (cast) { this.writer.write('!!'); - write(); + condition(); } else { - write(); + condition(); } + this.writer.write(` ? ${TRUE} : ${FALSE}`); } private plain(expr: Expression) { @@ -260,10 +261,8 @@ export class ExpressionWriter { if (!leftIsFieldAccess && !rightIsFieldAccess) { // compile down to a plain expression - this.block(() => { - this.guard(() => { - this.plain(expr); - }); + this.guard(() => { + this.plain(expr); }); return; } @@ -294,11 +293,11 @@ export class ExpressionWriter { if (this.isAuthOrAuthMemberAccess(operand) && !fieldAccess.$resolvedType?.nullable) { try { this.writer.write( - `(${this.plainExprBuilder.transform(operand)} == null) ? { ${GUARD_FIELD_NAME}: ${ + `(${this.plainExprBuilder.transform(operand)} == null) ? ${ // auth().x != user.x is true when auth().x is null and user is not nullable // other expressions are evaluated to false when null is involved - operator === '!=' ? 'true' : 'false' - } } : ` + operator === '!=' ? TRUE : FALSE + } : ` ); } catch (err) { if (err instanceof TypeScriptExpressionTransformerError) { @@ -555,11 +554,15 @@ export class ExpressionWriter { } private writeLiteral(expr: LiteralExpr) { - this.block(() => { + if (expr.value === true) { + this.writer.write(TRUE); + } else if (expr.value === false) { + this.writer.write(FALSE); + } else { this.guard(() => { this.plain(expr); }); - }); + } } private writeInvocation(expr: InvocationExpr) { @@ -575,7 +578,7 @@ export class ExpressionWriter { ) { if (!expr.args.some((arg) => this.isFieldAccess(arg.value))) { // filter functions without referencing fields - this.block(() => this.guard(() => this.plain(expr))); + this.guard(() => this.plain(expr)); return; } diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index 6102528a6..4e46072a1 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -23,7 +23,6 @@ import { getDataModels, getLiteral, getPrismaClientImportSpec, - GUARD_FIELD_NAME, hasAttribute, hasValidationAttributes, PluginError, @@ -466,7 +465,7 @@ export default class PolicyGenerator { writeAllows(); } else { // disallow any operation - writer.write(`{ ${GUARD_FIELD_NAME}: false }`); + writer.write(`{ OR: [] }`); } writer.write(';'); }); diff --git a/packages/schema/tests/generator/expression-writer.test.ts b/packages/schema/tests/generator/expression-writer.test.ts index e307c73b8..e35f07269 100644 --- a/packages/schema/tests/generator/expression-writer.test.ts +++ b/packages/schema/tests/generator/expression-writer.test.ts @@ -1,7 +1,6 @@ /// import { DataModel, Enum, Expression, isDataModel, isEnum } from '@zenstackhq/language/ast'; -import { GUARD_FIELD_NAME } from '@zenstackhq/sdk'; import * as tmp from 'tmp'; import { Project, VariableDeclarationKind } from 'ts-morph'; import { ExpressionWriter } from '../../src/plugins/access-policy/expression-writer'; @@ -17,7 +16,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{ ${GUARD_FIELD_NAME}: true }` + `{ AND: [] }` ); await check( @@ -28,7 +27,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{ ${GUARD_FIELD_NAME}: false }` + `{ OR: [] }` ); }); @@ -121,7 +120,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `(user == null) ? { zenstack_guard: false } : { id: user.id }` + `(user == null) ? { OR: [] } : { id: user.id }` ); await check( @@ -133,7 +132,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `(user == null) ? { zenstack_guard: true } : { NOT: { id: user.id } }` + `(user == null) ? { AND: [] } : { NOT: { id: user.id } }` ); await check( @@ -537,7 +536,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{ zenstack_guard: (user == null) }`, + `(user==null)?{AND:[]}:{OR:[]}`, '{ id: "1" }' ); @@ -555,7 +554,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{ zenstack_guard: (user == null) }`, + `(user==null)?{AND:[]}:{OR:[]}`, '{ x: "1", y: "2" }' ); @@ -571,7 +570,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{ zenstack_guard: (user != null) }`, + `(user!=null)?{AND:[]}:{OR:[]}`, '{ id: "1" }' ); @@ -589,7 +588,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{ zenstack_guard: (user != null) }`, + `(user!=null)?{AND:[]}:{OR:[]}`, '{ x: "1", y: "2" }' ); }); @@ -608,7 +607,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{ zenstack_guard: !!(user?.admin ?? null) }`, + `!!(user?.admin??null)?{AND:[]}:{OR:[]}`, '{ id: "1", admin: true }' ); @@ -625,7 +624,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{ NOT: { zenstack_guard: !!(user?.admin ?? null) } }`, + `{ NOT: !!(user?.admin??null)?{AND:[]}:{OR:[]} }`, '{ id: "1", admin: true }' ); }); @@ -646,7 +645,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `(user==null) ? { zenstack_guard: false } : { owner: { is: { id : user.id } } }` + `(user==null) ? { OR: [] } : { owner: { is: { id : user.id } } }` ); await check( @@ -664,7 +663,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `(user==null) ? { zenstack_guard: true } : + `(user==null) ? { AND: [] } : { owner: { isNot: { id: user.id } @@ -688,7 +687,7 @@ describe('Expression Writer Tests', () => { `, (model) => model.attributes[0].args[1].value, `((user?.id??null)==null) ? - { zenstack_guard : false } : + { OR: [] } : { owner: { id: { equals: (user?.id ?? null) } } }` ); }); @@ -714,7 +713,7 @@ describe('Expression Writer Tests', () => { `, (model) => model.attributes[1].args[1].value, `(user==null) ? - { zenstack_guard: false } : + { OR: [] } : { owner: { is: { x: user.x, y: user.y } } }`, '{ x: "1", y: "2" }' ); @@ -739,7 +738,7 @@ describe('Expression Writer Tests', () => { `, (model) => model.attributes[1].args[1].value, `(user==null) ? - { zenstack_guard: true } : + { AND: [] } : { owner: { isNot: { x: user.x, y: user.y } } }`, '{ x: "1", y: "2" }' ); @@ -765,8 +764,8 @@ describe('Expression Writer Tests', () => { (model) => model.attributes[1].args[1].value, `{ AND: [ - ((user?.x??null)==null) ? { zenstack_guard: false } : { owner: { x: { equals: (user?.x ?? null) } } }, - ((user?.y??null)==null) ? { zenstack_guard: false } : { owner: { y: { equals: (user?.y ?? null) } } } + ((user?.x??null)==null) ? { OR: [] } : { owner: { x: { equals: (user?.x ?? null) } } }, + ((user?.y??null)==null) ? { OR: [] } : { owner: { y: { equals: (user?.y ?? null) } } } ] }`, '{ x: "1", y: "2" }' @@ -833,7 +832,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `((user?.id??null)==null) ? { zenstack_guard: false } : { owner: { id: { equals: (user?.id ?? null) } } }` + `((user?.id??null)==null) ? { OR: [] } : { owner: { id: { equals: (user?.id ?? null) } } }` ); }); @@ -858,8 +857,8 @@ describe('Expression Writer Tests', () => { AND: [ { AND: [ - { zenstack_guard: (user!=null) }, - ((user?.id??null)==null) ? {zenstack_guard:false} : { owner: { id: { equals: (user?.id??null) } } } + (user!=null)?{AND:[]}:{OR:[]}, + ((user?.id??null)==null) ? {OR:[]} : { owner: { id: { equals: (user?.id??null) } } } ] }, { value: { gt: 0 } } @@ -887,8 +886,8 @@ describe('Expression Writer Tests', () => { OR: [ { OR: [ - { zenstack_guard:(user==null) }, - ((user?.id??null)==null) ? {zenstack_guard:true} : { owner : { id: { not: { equals: (user?.id??null) } } } } + (user==null)?{AND:[]}:{OR:[]}, + ((user?.id??null)==null) ? {AND:[]} : { owner : { id: { not: { equals: (user?.id??null) } } } } ] }, { value: { lte: 0 } } @@ -1159,7 +1158,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:(user?.roles?.includes(Role.ADMIN)??false)}`, + `(user?.roles?.includes(Role.ADMIN)??false)?{AND:[]}:{OR:[]}`, userInit ); @@ -1186,7 +1185,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:(user?.email?.includes('test')??false)}`, + `(user?.email?.includes('test')??false)?{AND:[]}:{OR:[]}`, userInit ); @@ -1199,7 +1198,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:(user?.email?.toLowerCase().includes('test'?.toLowerCase())??false)}`, + `(user?.email?.toLowerCase().includes('test'?.toLowerCase())??false)?{AND:[]}:{OR:[]}`, userInit ); @@ -1212,7 +1211,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:(user?.email?.startsWith('test')??false)}`, + `(user?.email?.startsWith('test')??false)?{AND:[]}:{OR:[]}`, userInit ); @@ -1225,7 +1224,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:(user?.email?.endsWith('test')??false)}`, + `(user?.email?.endsWith('test')??false)?{AND:[]}:{OR:[]}`, userInit ); @@ -1238,7 +1237,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:(user?.roles?.includes(Role.ADMIN)??false)}`, + `(user?.roles?.includes(Role.ADMIN)??false)?{AND:[]}:{OR:[]}`, userInit ); @@ -1251,7 +1250,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:([Role.ADMIN,Role.USER]?.every((item)=>user?.roles?.includes(item))??false)}`, + `([Role.ADMIN,Role.USER]?.every((item)=>user?.roles?.includes(item))??false)?{AND:[]}:{OR:[]}`, userInit ); @@ -1264,7 +1263,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:([Role.USER,Role.ADMIN]?.some((item)=>user?.roles?.includes(item))??false)}`, + `([Role.USER,Role.ADMIN]?.some((item)=>user?.roles?.includes(item))??false)?{AND:[]}:{OR:[]}`, userInit ); @@ -1277,7 +1276,7 @@ describe('Expression Writer Tests', () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:((!user?.roles||user?.roles?.length===0)??false)}`, + `((!user?.roles||user?.roles?.length===0)??false)?{AND:[]}:{OR:[]}`, userInit ); }); diff --git a/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts b/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts index 3cec4e02d..f2d2aa2ce 100644 --- a/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts +++ b/tests/integration/tests/enhancements/with-policy/deep-nested.test.ts @@ -68,7 +68,7 @@ describe('With Policy:deep nested', () => { }); beforeEach(async () => { - const params = await loadSchema(model, { logPrismaQuery: true }); + const params = await loadSchema(model); db = params.withPolicy(); prisma = params.prisma; }); diff --git a/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts b/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts index 645617f8c..9d5b9be4b 100644 --- a/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts +++ b/tests/integration/tests/enhancements/with-policy/nested-to-one.test.ts @@ -119,8 +119,7 @@ describe('With Policy:nested to-one', () => { @@allow('create', value > 0) @@allow('update', value > 1) } - `, - { logPrismaQuery: true } + ` ); const db = withPolicy(); diff --git a/tests/integration/tests/enhancements/with-policy/post-update.test.ts b/tests/integration/tests/enhancements/with-policy/post-update.test.ts index e4e45c0be..cc8e3c746 100644 --- a/tests/integration/tests/enhancements/with-policy/post-update.test.ts +++ b/tests/integration/tests/enhancements/with-policy/post-update.test.ts @@ -93,8 +93,7 @@ describe('With Policy: post update', () => { @@allow('create,read', true) @@allow('update', future().value > 1) } - `, - { logPrismaQuery: true } + ` ); const db = withPolicy(); diff --git a/tests/integration/tests/enhancements/with-policy/refactor.test.ts b/tests/integration/tests/enhancements/with-policy/refactor.test.ts index a4666c18e..a8f994298 100644 --- a/tests/integration/tests/enhancements/with-policy/refactor.test.ts +++ b/tests/integration/tests/enhancements/with-policy/refactor.test.ts @@ -28,7 +28,6 @@ describe('With Policy: refactor tests', () => { path.join(__dirname, '../../schema/refactor-pg.zmodel'), { addPrelude: false, - logPrismaQuery: true, } ); getDb = withPolicy; diff --git a/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts b/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts index 9bc78d302..15626e1c2 100644 --- a/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts +++ b/tests/integration/tests/enhancements/with-policy/toplevel-operations.test.ts @@ -72,8 +72,7 @@ describe('With Policy:toplevel operations', () => { @@allow('create', value > 0) @@allow('update', value > 1) } - `, - { logPrismaQuery: true } + ` ); const db = withPolicy(); diff --git a/tests/integration/tests/plugins/policy.test.ts b/tests/integration/tests/plugins/policy.test.ts index 9b2ab818b..7358a8fb6 100644 --- a/tests/integration/tests/plugins/policy.test.ts +++ b/tests/integration/tests/plugins/policy.test.ts @@ -64,10 +64,10 @@ model M { const { policy } = await loadSchema(model); expect(policy.guard.m.read({ user: undefined })).toEqual( - expect.objectContaining({ AND: [{ zenstack_guard: false }, { value: { gt: 0 } }] }) + expect.objectContaining({ AND: [{ OR: [] }, { value: { gt: 0 } }] }) ); expect(policy.guard.m.read({ user: { id: '1' } })).toEqual( - expect.objectContaining({ AND: [{ zenstack_guard: true }, { value: { gt: 0 } }] }) + expect.objectContaining({ AND: [{ AND: [] }, { value: { gt: 0 } }] }) ); }); }); From 1b7b5bda3f5106d31b7f5e70be27158fb8217600 Mon Sep 17 00:00:00 2001 From: Yiming Date: Sat, 5 Aug 2023 17:58:15 +0800 Subject: [PATCH 5/5] fix: improve consistency of generated guard code (#616) --- package.json | 2 +- packages/language/package.json | 2 +- packages/plugins/openapi/package.json | 2 +- packages/plugins/swr/package.json | 2 +- packages/plugins/tanstack-query/package.json | 2 +- packages/plugins/trpc/package.json | 2 +- packages/runtime/package.json | 2 +- packages/schema/package.json | 2 +- .../access-policy/expression-writer.ts | 4 ++-- .../access-policy/policy-guard-generator.ts | 8 +++---- packages/sdk/package.json | 2 +- packages/server/package.json | 2 +- packages/testtools/package.json | 2 +- pnpm-lock.yaml | 14 +++++++++++ .../integration/tests/plugins/policy.test.ts | 23 +++++++++++-------- 15 files changed, 44 insertions(+), 27 deletions(-) diff --git a/package.json b/package.json index 4c35012c0..cb917a35e 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "zenstack-monorepo", - "version": "1.0.0-beta.13", + "version": "1.0.0-beta.15", "description": "", "scripts": { "build": "pnpm -r build", diff --git a/packages/language/package.json b/packages/language/package.json index 5c2e8bc97..5a43da152 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/language", - "version": "1.0.0-beta.13", + "version": "1.0.0-beta.15", "displayName": "ZenStack modeling language compiler", "description": "ZenStack modeling language compiler", "homepage": "https://zenstack.dev", diff --git a/packages/plugins/openapi/package.json b/packages/plugins/openapi/package.json index 175cd84b7..888ad9fde 100644 --- a/packages/plugins/openapi/package.json +++ b/packages/plugins/openapi/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/openapi", "displayName": "ZenStack Plugin and Runtime for OpenAPI", - "version": "1.0.0-beta.13", + "version": "1.0.0-beta.15", "description": "ZenStack plugin and runtime supporting OpenAPI", "main": "index.js", "repository": { diff --git a/packages/plugins/swr/package.json b/packages/plugins/swr/package.json index 083912f6c..496c40008 100644 --- a/packages/plugins/swr/package.json +++ b/packages/plugins/swr/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/swr", "displayName": "ZenStack plugin for generating SWR hooks", - "version": "1.0.0-beta.13", + "version": "1.0.0-beta.15", "description": "ZenStack plugin for generating SWR hooks", "main": "index.js", "repository": { diff --git a/packages/plugins/tanstack-query/package.json b/packages/plugins/tanstack-query/package.json index 130c4b051..bf87033f7 100644 --- a/packages/plugins/tanstack-query/package.json +++ b/packages/plugins/tanstack-query/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/tanstack-query", "displayName": "ZenStack plugin for generating tanstack-query hooks", - "version": "1.0.0-beta.13", + "version": "1.0.0-beta.15", "description": "ZenStack plugin for generating tanstack-query hooks", "main": "index.js", "exports": { diff --git a/packages/plugins/trpc/package.json b/packages/plugins/trpc/package.json index 3f5723880..10c66c040 100644 --- a/packages/plugins/trpc/package.json +++ b/packages/plugins/trpc/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/trpc", "displayName": "ZenStack plugin for tRPC", - "version": "1.0.0-beta.13", + "version": "1.0.0-beta.15", "description": "ZenStack plugin for tRPC", "main": "index.js", "repository": { diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 7cf48b668..0140101fd 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/runtime", "displayName": "ZenStack Runtime Library", - "version": "1.0.0-beta.13", + "version": "1.0.0-beta.15", "description": "Runtime of ZenStack for both client-side and server-side environments.", "repository": { "type": "git", diff --git a/packages/schema/package.json b/packages/schema/package.json index 8de326b9a..73d1e5e87 100644 --- a/packages/schema/package.json +++ b/packages/schema/package.json @@ -3,7 +3,7 @@ "publisher": "zenstack", "displayName": "ZenStack Language Tools", "description": "A toolkit for building secure CRUD apps with Next.js + Typescript", - "version": "1.0.0-beta.13", + "version": "1.0.0-beta.15", "author": { "name": "ZenStack Team" }, diff --git a/packages/schema/src/plugins/access-policy/expression-writer.ts b/packages/schema/src/plugins/access-policy/expression-writer.ts index e45bce0cb..3c7fbdd1d 100644 --- a/packages/schema/src/plugins/access-policy/expression-writer.ts +++ b/packages/schema/src/plugins/access-policy/expression-writer.ts @@ -43,8 +43,8 @@ type FilterOperators = // { OR: [] } filters to nothing, { AND: [] } includes everything // https://www.prisma.io/docs/concepts/components/prisma-client/null-and-undefined#the-effect-of-null-and-undefined-on-conditionals -const TRUE = '{ AND: [] }'; -const FALSE = '{ OR: [] }'; +export const TRUE = '{ AND: [] }'; +export const FALSE = '{ OR: [] }'; /** * Utility for writing ZModel expression as Prisma query argument objects into a ts-morph writer diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index 4e46072a1..1fbba6800 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -50,7 +50,7 @@ import { TypeScriptExpressionTransformerError, } from '../../utils/typescript-expression-transformer'; import { ALL_OPERATION_KINDS, getDefaultOutputFolder } from '../plugin-utils'; -import { ExpressionWriter } from './expression-writer'; +import { ExpressionWriter, FALSE, TRUE } from './expression-writer'; import { isFutureExpr } from './utils'; /** @@ -414,10 +414,10 @@ export default class PolicyGenerator { }); try { denies.forEach((rule) => { - writer.write(`if (${transformer.transform(rule, false)}) { return false; }`); + writer.write(`if (${transformer.transform(rule, false)}) { return ${FALSE}; }`); }); allows.forEach((rule) => { - writer.write(`if (${transformer.transform(rule, false)}) { return true; }`); + writer.write(`if (${transformer.transform(rule, false)}) { return ${TRUE}; }`); }); } catch (err) { if (err instanceof TypeScriptExpressionTransformerError) { @@ -426,7 +426,7 @@ export default class PolicyGenerator { throw err; } } - writer.write('return false;'); + writer.write(`return ${FALSE};`); }); } else { statements.push((writer) => { diff --git a/packages/sdk/package.json b/packages/sdk/package.json index d9e744361..a24fa7062 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/sdk", - "version": "1.0.0-beta.13", + "version": "1.0.0-beta.15", "description": "ZenStack plugin development SDK", "main": "index.js", "scripts": { diff --git a/packages/server/package.json b/packages/server/package.json index 030f70d3d..cee55fb36 100644 --- a/packages/server/package.json +++ b/packages/server/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/server", - "version": "1.0.0-beta.13", + "version": "1.0.0-beta.15", "displayName": "ZenStack Server-side Adapters", "description": "ZenStack server-side adapters", "homepage": "https://zenstack.dev", diff --git a/packages/testtools/package.json b/packages/testtools/package.json index 1836a12ef..f539df544 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/testtools", - "version": "1.0.0-beta.13", + "version": "1.0.0-beta.15", "description": "ZenStack Test Tools", "main": "index.js", "publishConfig": { diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index a2294dad3..01fdcb97d 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -4990,6 +4990,7 @@ packages: /chownr@1.1.4: resolution: {integrity: sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg==} + requiresBuild: true dev: true optional: true @@ -5430,6 +5431,7 @@ packages: /decompress-response@6.0.0: resolution: {integrity: sha512-aW35yZM6Bb/4oJlZncMH2LCoZtJXTRxES17vE3hoRiowU2kWHaJKFkSBDnDR+cm9J+9QhXmREyIfv0pji9ejCQ==} engines: {node: '>=10'} + requiresBuild: true dependencies: mimic-response: 3.1.0 dev: true @@ -5527,6 +5529,7 @@ packages: /detect-libc@2.0.1: resolution: {integrity: sha512-463v3ZeIrcWtdgIg6vI6XUncguvr2TnGl4SzDXinkt9mSLpBJKXT3mW6xT3VQdDN11+WVs29pgvivTc4Lp8v+w==} engines: {node: '>=8'} + requiresBuild: true dev: true optional: true @@ -6393,6 +6396,7 @@ packages: /expand-template@2.0.3: resolution: {integrity: sha512-XYfuKMvj4O35f/pOXLObndIRvyQ+/+6AhODh+OKWj9S9498pHHn/IMszH+gt0fBCRWMNfk1ZSp5x3AifmnI2vg==} engines: {node: '>=6'} + requiresBuild: true dev: true optional: true @@ -6879,6 +6883,7 @@ packages: /github-from-package@0.0.0: resolution: {integrity: sha512-SyHy3T1v2NUXn29OsWdxmK6RwHD+vkj3v8en8AOBZ1wBQ/hCAQ5bAQTD02kW4W9tUp/3Qh6J8r9EvntiyCmOOw==} + requiresBuild: true dev: true optional: true @@ -8475,6 +8480,7 @@ packages: /mkdirp-classic@0.5.3: resolution: {integrity: sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A==} + requiresBuild: true dev: true optional: true @@ -8537,6 +8543,7 @@ packages: /napi-build-utils@1.0.2: resolution: {integrity: sha512-ONmRUqK7zj7DWX0D9ADe03wbwOBZxNAfF20PlGfCWQcD3+/MakShIHrMqx9YwPTfxDdF1zLeL+RGZiR9kGMLdg==} + requiresBuild: true dev: true optional: true @@ -8699,6 +8706,7 @@ packages: /node-abi@3.45.0: resolution: {integrity: sha512-iwXuFrMAcFVi/ZoZiqq8BzAdsLw9kxDfTC0HMyjXfSL/6CSDAGD5UmR7azrAgWV1zKYq7dUUMj4owusBWKLsiQ==} engines: {node: '>=10'} + requiresBuild: true dependencies: semver: 7.5.3 dev: true @@ -8706,6 +8714,7 @@ packages: /node-addon-api@4.3.0: resolution: {integrity: sha512-73sE9+3UaLYYFmDsFZnqCInzPyh3MqIwZO9cw58yIqAZhONrrabrYyYe3TuIqtIiOuTXVhsGau8hcrhhwSsDIQ==} + requiresBuild: true dev: true optional: true @@ -9344,6 +9353,7 @@ packages: resolution: {integrity: sha512-jAXscXWMcCK8GgCoHOfIr0ODh5ai8mj63L2nWrjuAgXE6tDyYGnx4/8o/rCgU+B4JSyZBKbeZqzhtwtC3ovxjw==} engines: {node: '>=10'} hasBin: true + requiresBuild: true dependencies: detect-libc: 2.0.1 expand-template: 2.0.3 @@ -9470,6 +9480,7 @@ packages: /pump@3.0.0: resolution: {integrity: sha512-LwZy+p3SFs1Pytd/jYct4wpv49HiYCqd9Rlc5ZVdk0V+8Yzv6jR5Blk3TRmPL1ft69TxP0IMZGJ+WPFU2BFhww==} + requiresBuild: true dependencies: end-of-stream: 1.4.4 once: 1.4.0 @@ -10035,11 +10046,13 @@ packages: /simple-concat@1.0.1: resolution: {integrity: sha512-cSFtAPtRhljv69IK0hTVZQ+OfE9nePi/rtJmw5UjHeVyVroEqJXP1sFztKUy1qU+xvz3u/sfYJLa947b7nAN2Q==} + requiresBuild: true dev: true optional: true /simple-get@4.0.1: resolution: {integrity: sha512-brv7p5WgH0jmQJr1ZDDfKDOSeWWg+OVypG99A/5vYGPqJ6pxiaHLy8nxtFjBA7oMa01ebA9gfh1uMCFqOuXxvA==} + requiresBuild: true dependencies: decompress-response: 6.0.0 once: 1.4.0 @@ -10497,6 +10510,7 @@ packages: /tar-fs@2.1.1: resolution: {integrity: sha512-V0r2Y9scmbDRLCNex/+hYzvp/zyYjvFbHPNgVTKfQvVrb6guiE/fxP+XblDNR011utopbkex2nM4dHNV6GDsng==} + requiresBuild: true dependencies: chownr: 1.1.4 mkdirp-classic: 0.5.3 diff --git a/tests/integration/tests/plugins/policy.test.ts b/tests/integration/tests/plugins/policy.test.ts index 7358a8fb6..4b67dae7c 100644 --- a/tests/integration/tests/plugins/policy.test.ts +++ b/tests/integration/tests/plugins/policy.test.ts @@ -13,6 +13,9 @@ describe('Policy plugin tests', () => { process.chdir(origDir); }); + const TRUE = { AND: [] }; + const FALSE = { OR: [] }; + it('short-circuit', async () => { const model = ` model User { @@ -33,18 +36,18 @@ model M { const { policy } = await loadSchema(model); - expect(policy.guard.m.read({ user: undefined })).toEqual(false); - expect(policy.guard.m.read({ user: { id: '1' } })).toEqual(true); + expect(policy.guard.m.read({ user: undefined })).toEqual(FALSE); + expect(policy.guard.m.read({ user: { id: '1' } })).toEqual(TRUE); - expect(policy.guard.m.create({ user: undefined })).toEqual(false); - expect(policy.guard.m.create({ user: { id: '1' } })).toEqual(false); - expect(policy.guard.m.create({ user: { id: '1', value: 0 } })).toEqual(false); - expect(policy.guard.m.create({ user: { id: '1', value: 1 } })).toEqual(true); + expect(policy.guard.m.create({ user: undefined })).toEqual(FALSE); + expect(policy.guard.m.create({ user: { id: '1' } })).toEqual(FALSE); + expect(policy.guard.m.create({ user: { id: '1', value: 0 } })).toEqual(FALSE); + expect(policy.guard.m.create({ user: { id: '1', value: 1 } })).toEqual(TRUE); - expect(policy.guard.m.update({ user: undefined })).toEqual(false); - expect(policy.guard.m.update({ user: { id: '1' } })).toEqual(false); - expect(policy.guard.m.update({ user: { id: '1', value: 0 } })).toEqual(false); - expect(policy.guard.m.update({ user: { id: '1', value: 1 } })).toEqual(true); + expect(policy.guard.m.update({ user: undefined })).toEqual(FALSE); + expect(policy.guard.m.update({ user: { id: '1' } })).toEqual(FALSE); + expect(policy.guard.m.update({ user: { id: '1', value: 0 } })).toEqual(FALSE); + expect(policy.guard.m.update({ user: { id: '1', value: 1 } })).toEqual(TRUE); }); it('no short-circuit', async () => {