diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 78a5d1400..dd3649e55 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -524,6 +524,8 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'data field is required in query argument'); } + args = this.utils.clone(args); + const { result, error } = await this.transaction(async (tx) => { // proceed with nested writes and collect post-write checks const { result, postWriteChecks } = await this.doUpdate(args, tx); @@ -543,8 +545,6 @@ export class PolicyProxyHandler implements Pr } private async doUpdate(args: any, db: Record) { - args = this.utils.clone(args); - // collected post-update checks const postWriteChecks: PostWriteCheckRecord[] = []; @@ -903,6 +903,8 @@ export class PolicyProxyHandler implements Pr await this.utils.tryReject(this.prisma, this.model, 'create'); await this.utils.tryReject(this.prisma, this.model, 'update'); + args = this.utils.clone(args); + // We can call the native "upsert" because we can't tell if an entity was created or updated // for doing post-write check accordingly. Instead, decompose it into create or update. @@ -998,6 +1000,8 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'query argument is required'); } + args = this.utils.clone(args); + // inject policy conditions await this.utils.injectAuthGuard(this.prisma, args, this.model, 'read'); @@ -1012,6 +1016,8 @@ export class PolicyProxyHandler implements Pr throw prismaClientValidationError(this.prisma, 'query argument is required'); } + args = this.utils.clone(args); + // inject policy conditions await this.utils.injectAuthGuard(this.prisma, args, this.model, 'read'); @@ -1023,7 +1029,7 @@ export class PolicyProxyHandler implements Pr async count(args: any) { // inject policy conditions - args = args ?? {}; + args = args ? this.utils.clone(args) : {}; await this.utils.injectAuthGuard(this.prisma, args, this.model, 'read'); if (this.shouldLogQuery) { @@ -1034,6 +1040,55 @@ export class PolicyProxyHandler implements Pr //#endregion + //#region Subscribe (Prisma Pulse) + + async subscribe(args: any) { + const readGuard = this.utils.getAuthGuard(this.prisma, this.model, 'read'); + if (this.utils.isTrue(readGuard)) { + // no need to inject + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`subscribe\` ${this.model}:\n${formatObject(args)}`); + } + return this.modelClient.subscribe(args); + } + + if (!args) { + // include all + args = { create: {}, update: {}, delete: {} }; + } else { + if (typeof args !== 'object') { + throw prismaClientValidationError(this.prisma, 'argument must be an object'); + } + if (Object.keys(args).length === 0) { + // include all + args = { create: {}, update: {}, delete: {} }; + } else { + args = this.utils.clone(args); + } + } + + // inject into subscribe conditions + + if (args.create) { + args.create.after = this.utils.and(args.create.after, readGuard); + } + + if (args.update) { + args.update.after = this.utils.and(args.update.after, readGuard); + } + + if (args.delete) { + args.delete.before = this.utils.and(args.delete.before, readGuard); + } + + if (this.shouldLogQuery) { + this.logger.info(`[policy] \`subscribe\` ${this.model}:\n${formatObject(args)}`); + } + return this.modelClient.subscribe(args); + } + + //#endregion + //#region Utils private get shouldLogQuery() { diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index cf77fe4b0..f6e69f086 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -81,7 +81,7 @@ export class PolicyUtil { // Static True/False conditions // https://www.prisma.io/docs/concepts/components/prisma-client/null-and-undefined#the-effect-of-null-and-undefined-on-conditionals - private isTrue(condition: object) { + public isTrue(condition: object) { if (condition === null || condition === undefined) { return false; } else { @@ -92,7 +92,7 @@ export class PolicyUtil { } } - private isFalse(condition: object) { + public isFalse(condition: object) { if (condition === null || condition === undefined) { return false; } else { diff --git a/packages/runtime/src/enhancements/proxy.ts b/packages/runtime/src/enhancements/proxy.ts index 43cc36a30..717f63d2e 100644 --- a/packages/runtime/src/enhancements/proxy.ts +++ b/packages/runtime/src/enhancements/proxy.ts @@ -42,6 +42,8 @@ export interface PrismaProxyHandler { groupBy(args: any): Promise; count(args: any): Promise; + + subscribe(args: any): Promise; } /** @@ -141,6 +143,11 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler { return this.prisma[this.model].count(args); } + async subscribe(args: any): Promise { + args = await this.preprocessArgs('subscribe', args); + return this.prisma[this.model].subscribe(args); + } + /** * Processes result entities before they're returned */ diff --git a/packages/runtime/src/types.ts b/packages/runtime/src/types.ts index b796fd0b2..76366d87e 100644 --- a/packages/runtime/src/types.ts +++ b/packages/runtime/src/types.ts @@ -19,6 +19,7 @@ export interface DbOperations { aggregate(args: unknown): Promise; groupBy(args: unknown): Promise; count(args?: unknown): Promise; + subscribe(args?: unknown): Promise; fields: Record; } diff --git a/packages/testtools/src/schema.ts b/packages/testtools/src/schema.ts index 958432fc3..ae81d453e 100644 --- a/packages/testtools/src/schema.ts +++ b/packages/testtools/src/schema.ts @@ -99,6 +99,7 @@ export type SchemaLoadOptions = { logPrismaQuery?: boolean; provider?: 'sqlite' | 'postgresql'; dbUrl?: string; + pulseApiKey?: string; }; const defaultOptions: SchemaLoadOptions = { @@ -187,14 +188,23 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) { run('npx prisma db push'); } - const PrismaClient = require(path.join(projectRoot, 'node_modules/.prisma/client')).PrismaClient; - const prisma = new PrismaClient({ log: ['info', 'warn', 'error'] }); + if (opt.pulseApiKey) { + opt.extraDependencies?.push('@prisma/extension-pulse'); + } opt.extraDependencies?.forEach((dep) => { console.log(`Installing dependency ${dep}`); run(`npm install ${dep}`); }); + const PrismaClient = require(path.join(projectRoot, 'node_modules/.prisma/client')).PrismaClient; + let prisma = new PrismaClient({ log: ['info', 'warn', 'error'] }); + + if (opt.pulseApiKey) { + const withPulse = require(path.join(projectRoot, 'node_modules/@prisma/extension-pulse/dist/cjs')).withPulse; + prisma = prisma.$extends(withPulse({ apiKey: opt.pulseApiKey })); + } + if (opt.compile) { console.log('Compiling...'); run('npx tsc --init'); diff --git a/tests/integration/tests/enhancements/with-policy/subscription.test.ts b/tests/integration/tests/enhancements/with-policy/subscription.test.ts new file mode 100644 index 000000000..2befdd42a --- /dev/null +++ b/tests/integration/tests/enhancements/with-policy/subscription.test.ts @@ -0,0 +1,264 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import path from 'path'; + +const DB_URL = ''; +const PULSE_API_KEY = ''; + +// eslint-disable-next-line jest/no-disabled-tests +describe.skip('With Policy: subscription test', () => { + let origDir: string; + + beforeAll(async () => { + origDir = path.resolve('.'); + }); + + afterEach(() => { + process.chdir(origDir); + }); + + it('subscribe auth check', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model User { + id Int @id @default(autoincrement()) + } + + model Model { + id Int @id @default(autoincrement()) + name String + + @@allow('read', auth() != null) + } + `, + { + provider: 'postgresql', + dbUrl: DB_URL, + pulseApiKey: PULSE_API_KEY, + logPrismaQuery: true, + } + ); + + await prisma.model.deleteMany(); + + const rawSub = await prisma.model.subscribe(); + + const anonDb = withPolicy(); + console.log('Anonymous db subscribing'); + const anonSub = await anonDb.model.subscribe(); + + const authDb = withPolicy({ id: 1 }); + console.log('Auth db subscribing'); + const authSub = await authDb.model.subscribe(); + + async function produce() { + await prisma.model.create({ data: { id: 1, name: 'abc' } }); + console.log('created'); + await prisma.model.update({ where: { id: 1 }, data: { name: 'bcd' } }); + console.log('updated'); + await prisma.model.delete({ where: { id: 1 } }); + console.log('deleted'); + await new Promise((resolve) => setTimeout(resolve, 2000)); + } + + const rawEvents: any[] = []; + const authEvents: any[] = []; + const anonEvents: any[] = []; + await Promise.race([ + produce(), + consume(rawSub, 'Raw', rawEvents), + consume(authSub, 'Auth', authEvents), + consume(anonSub, 'Anonymous', anonEvents), + ]); + expect(rawEvents.length).toBe(3); + expect(authEvents.length).toBe(3); + expect(anonEvents.length).toBe(0); + }); + + it('subscribe model check', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + name String + + @@allow('read', contains(name, 'hello')) + } + `, + { + provider: 'postgresql', + dbUrl: DB_URL, + pulseApiKey: PULSE_API_KEY, + logPrismaQuery: true, + } + ); + + await prisma.model.deleteMany(); + + const rawSub = await prisma.model.subscribe(); + + const enhanced = withPolicy(); + console.log('Auth db subscribing'); + const enhancedSub = await enhanced.model.subscribe(); + + async function produce() { + await prisma.model.create({ data: { id: 1, name: 'abc' } }); + console.log('created'); + await prisma.model.update({ where: { id: 1 }, data: { name: 'bcd' } }); + console.log('updated'); + await prisma.model.delete({ where: { id: 1 } }); + console.log('deleted'); + + await prisma.model.create({ data: { id: 2, name: 'hello world' } }); + console.log('created'); + await prisma.model.update({ where: { id: 2 }, data: { name: 'hello moon' } }); + console.log('updated'); + await prisma.model.delete({ where: { id: 2 } }); + console.log('deleted'); + + await new Promise((resolve) => setTimeout(resolve, 2000)); + } + + const rawEvents: any[] = []; + const enhancedEvents: any[] = []; + await Promise.race([ + produce(), + consume(rawSub, 'Raw', rawEvents), + consume(enhancedSub, 'Enhanced', enhancedEvents), + ]); + expect(rawEvents.length).toBe(6); + expect(enhancedEvents.length).toBe(3); + }); + + it('subscribe partial', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + name String + + @@allow('read', contains(name, 'hello')) + } + `, + { + provider: 'postgresql', + dbUrl: DB_URL, + pulseApiKey: PULSE_API_KEY, + logPrismaQuery: true, + } + ); + + await prisma.model.deleteMany(); + + const rawSub = await prisma.model.subscribe({ create: {} }); + + const enhanced = withPolicy(); + console.log('Auth db subscribing'); + const enhancedSub = await enhanced.model.subscribe({ create: {} }); + + async function produce() { + await prisma.model.create({ data: { id: 1, name: 'abc' } }); + console.log('created'); + await prisma.model.update({ where: { id: 1 }, data: { name: 'bcd' } }); + console.log('updated'); + await prisma.model.delete({ where: { id: 1 } }); + console.log('deleted'); + + await prisma.model.create({ data: { id: 2, name: 'hello world' } }); + console.log('created'); + await prisma.model.update({ where: { id: 2 }, data: { name: 'hello moon' } }); + console.log('updated'); + await prisma.model.delete({ where: { id: 2 } }); + console.log('deleted'); + + await new Promise((resolve) => setTimeout(resolve, 2000)); + } + + const rawEvents: any[] = []; + const enhancedEvents: any[] = []; + await Promise.race([ + produce(), + consume(rawSub, 'Raw', rawEvents), + consume(enhancedSub, 'Enhanced', enhancedEvents), + ]); + expect(rawEvents.length).toBe(2); + expect(enhancedEvents.length).toBe(1); + }); + + it('subscribe mixed model check', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model Model { + id Int @id @default(autoincrement()) + name String + + @@allow('read', contains(name, 'hello')) + } + `, + { + provider: 'postgresql', + dbUrl: DB_URL, + pulseApiKey: PULSE_API_KEY, + logPrismaQuery: true, + } + ); + + await prisma.model.deleteMany(); + + const rawSub = await prisma.model.subscribe({ + create: { after: { name: { contains: 'world' } } }, + update: { after: { name: { contains: 'world' } } }, + delete: { before: { name: { contains: 'world' } } }, + }); + + const enhanced = withPolicy(); + console.log('Auth db subscribing'); + const enhancedSub = await enhanced.model.subscribe({ + create: { after: { name: { contains: 'world' } } }, + update: { after: { name: { contains: 'world' } } }, + delete: { before: { name: { contains: 'world' } } }, + }); + + async function produce() { + await prisma.model.create({ data: { id: 1, name: 'abc' } }); + console.log('created'); + await prisma.model.update({ where: { id: 1 }, data: { name: 'bcd' } }); + console.log('updated'); + await prisma.model.delete({ where: { id: 1 } }); + console.log('deleted'); + + await prisma.model.create({ data: { id: 2, name: 'good world' } }); + console.log('created'); + await prisma.model.update({ where: { id: 2 }, data: { name: 'nice world' } }); + console.log('updated'); + await prisma.model.delete({ where: { id: 2 } }); + console.log('deleted'); + + await prisma.model.create({ data: { id: 3, name: 'hello world' } }); + console.log('created'); + await prisma.model.update({ where: { id: 3 }, data: { name: 'hello nice world' } }); + console.log('updated'); + await prisma.model.delete({ where: { id: 3 } }); + console.log('deleted'); + + await new Promise((resolve) => setTimeout(resolve, 2000)); + } + + const rawEvents: any[] = []; + const enhancedEvents: any[] = []; + await Promise.race([ + produce(), + consume(rawSub, 'Raw', rawEvents), + consume(enhancedSub, 'Enhanced', enhancedEvents), + ]); + expect(rawEvents.length).toBe(6); + expect(enhancedEvents.length).toBe(3); + }); +}); + +async function consume(subscription: any, name: string, events: any[]) { + console.log('Consuming', name); + for await (const event of subscription) { + console.log(name, 'got event:', event); + events.push(event); + } +}