Skip to content

feat: implementing access control for Prisma Pulse #643

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 58 additions & 3 deletions packages/runtime/src/enhancements/policy/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,8 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
throw prismaClientValidationError(this.prisma, 'data field is required in query argument');
}

args = this.utils.clone(args);

const { result, error } = await this.transaction(async (tx) => {
// proceed with nested writes and collect post-write checks
const { result, postWriteChecks } = await this.doUpdate(args, tx);
Expand All @@ -543,8 +545,6 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}

private async doUpdate(args: any, db: Record<string, DbOperations>) {
args = this.utils.clone(args);

// collected post-update checks
const postWriteChecks: PostWriteCheckRecord[] = [];

Expand Down Expand Up @@ -903,6 +903,8 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
await this.utils.tryReject(this.prisma, this.model, 'create');
await this.utils.tryReject(this.prisma, this.model, 'update');

args = this.utils.clone(args);

// We can call the native "upsert" because we can't tell if an entity was created or updated
// for doing post-write check accordingly. Instead, decompose it into create or update.

Expand Down Expand Up @@ -998,6 +1000,8 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
throw prismaClientValidationError(this.prisma, 'query argument is required');
}

args = this.utils.clone(args);

// inject policy conditions
await this.utils.injectAuthGuard(this.prisma, args, this.model, 'read');

Expand All @@ -1012,6 +1016,8 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
throw prismaClientValidationError(this.prisma, 'query argument is required');
}

args = this.utils.clone(args);

// inject policy conditions
await this.utils.injectAuthGuard(this.prisma, args, this.model, 'read');

Expand All @@ -1023,7 +1029,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

async count(args: any) {
// inject policy conditions
args = args ?? {};
args = args ? this.utils.clone(args) : {};
await this.utils.injectAuthGuard(this.prisma, args, this.model, 'read');

if (this.shouldLogQuery) {
Expand All @@ -1034,6 +1040,55 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

//#endregion

//#region Subscribe (Prisma Pulse)

async subscribe(args: any) {
const readGuard = this.utils.getAuthGuard(this.prisma, this.model, 'read');
if (this.utils.isTrue(readGuard)) {
// no need to inject
if (this.shouldLogQuery) {
this.logger.info(`[policy] \`subscribe\` ${this.model}:\n${formatObject(args)}`);
}
return this.modelClient.subscribe(args);
}

if (!args) {
// include all
args = { create: {}, update: {}, delete: {} };
} else {
if (typeof args !== 'object') {
throw prismaClientValidationError(this.prisma, 'argument must be an object');
}
if (Object.keys(args).length === 0) {
// include all
args = { create: {}, update: {}, delete: {} };
} else {
args = this.utils.clone(args);
}
}

// inject into subscribe conditions

if (args.create) {
args.create.after = this.utils.and(args.create.after, readGuard);
}

if (args.update) {
args.update.after = this.utils.and(args.update.after, readGuard);
}

if (args.delete) {
args.delete.before = this.utils.and(args.delete.before, readGuard);
}

if (this.shouldLogQuery) {
this.logger.info(`[policy] \`subscribe\` ${this.model}:\n${formatObject(args)}`);
}
return this.modelClient.subscribe(args);
}

//#endregion

//#region Utils

private get shouldLogQuery() {
Expand Down
4 changes: 2 additions & 2 deletions packages/runtime/src/enhancements/policy/policy-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ export class PolicyUtil {
// Static True/False conditions
// https://www.prisma.io/docs/concepts/components/prisma-client/null-and-undefined#the-effect-of-null-and-undefined-on-conditionals

private isTrue(condition: object) {
public isTrue(condition: object) {
if (condition === null || condition === undefined) {
return false;
} else {
Expand All @@ -92,7 +92,7 @@ export class PolicyUtil {
}
}

private isFalse(condition: object) {
public isFalse(condition: object) {
if (condition === null || condition === undefined) {
return false;
} else {
Expand Down
7 changes: 7 additions & 0 deletions packages/runtime/src/enhancements/proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ export interface PrismaProxyHandler {
groupBy(args: any): Promise<unknown>;

count(args: any): Promise<unknown | number>;

subscribe(args: any): Promise<unknown>;
}

/**
Expand Down Expand Up @@ -141,6 +143,11 @@ export class DefaultPrismaProxyHandler implements PrismaProxyHandler {
return this.prisma[this.model].count(args);
}

async subscribe(args: any): Promise<unknown> {
args = await this.preprocessArgs('subscribe', args);
return this.prisma[this.model].subscribe(args);
}

/**
* Processes result entities before they're returned
*/
Expand Down
1 change: 1 addition & 0 deletions packages/runtime/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export interface DbOperations {
aggregate(args: unknown): Promise<unknown>;
groupBy(args: unknown): Promise<unknown>;
count(args?: unknown): Promise<unknown>;
subscribe(args?: unknown): Promise<unknown>;
fields: Record<string, any>;
}

Expand Down
14 changes: 12 additions & 2 deletions packages/testtools/src/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ export type SchemaLoadOptions = {
logPrismaQuery?: boolean;
provider?: 'sqlite' | 'postgresql';
dbUrl?: string;
pulseApiKey?: string;
};

const defaultOptions: SchemaLoadOptions = {
Expand Down Expand Up @@ -187,14 +188,23 @@ export async function loadSchema(schema: string, options?: SchemaLoadOptions) {
run('npx prisma db push');
}

const PrismaClient = require(path.join(projectRoot, 'node_modules/.prisma/client')).PrismaClient;
const prisma = new PrismaClient({ log: ['info', 'warn', 'error'] });
if (opt.pulseApiKey) {
opt.extraDependencies?.push('@prisma/extension-pulse');
}

opt.extraDependencies?.forEach((dep) => {
console.log(`Installing dependency ${dep}`);
run(`npm install ${dep}`);
});

const PrismaClient = require(path.join(projectRoot, 'node_modules/.prisma/client')).PrismaClient;
let prisma = new PrismaClient({ log: ['info', 'warn', 'error'] });

if (opt.pulseApiKey) {
const withPulse = require(path.join(projectRoot, 'node_modules/@prisma/extension-pulse/dist/cjs')).withPulse;
prisma = prisma.$extends(withPulse({ apiKey: opt.pulseApiKey }));
}

if (opt.compile) {
console.log('Compiling...');
run('npx tsc --init');
Expand Down
Loading