Skip to content

feat: field-level access control #638

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Aug 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ The `zenstack` CLI transpiles the ZModel into a standard Prisma schema, which yo
At runtime, transparent proxies are created around Prisma clients for intercepting queries and mutations to enforce access policies.

```ts
import { withPolicy } from '@zenstackhq/runtime';
import { enhance } from '@zenstackhq/runtime';

// a regular Prisma client
const prisma = new PrismaClient();

async function getPosts(userId: string) {
// create an enhanced Prisma client that has access control enabled
const enhanced = withPolicy(prisma, { user: userId });
const enhanced = enhance(prisma, { user: userId });

// only posts that're visible to the user will be returned
return enhanced.post.findMany();
Expand All @@ -84,14 +84,14 @@ Server adapter packages help you wrap an access-control-enabled Prisma client in
// pages/api/model/[...path].ts

import { requestHandler } from '@zenstackhq/next';
import { withPolicy } from '@zenstackhq/runtime';
import { enhance } from '@zenstackhq/runtime';
import { getSessionUser } from '@lib/auth';
import { prisma } from '@lib/db';

// Mount Prisma-style APIs: "/api/model/post/findMany", "/api/model/post/create", etc.
// Can be configured to provide standard RESTful APIs (using JSON:API) instead.
export default requestHandler({
getPrisma: (req, res) => withPolicy(prisma, { user: getSessionUser(req, res) }),
getPrisma: (req, res) => enhance(prisma, { user: getSessionUser(req, res) }),
});
```

Expand Down
25 changes: 25 additions & 0 deletions packages/runtime/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,28 @@ export const PRISMA_PROXY_ENHANCER = '$__zenstack_enhancer';
* Minimum Prisma version supported
*/
export const PRISMA_MINIMUM_VERSION = '4.8.0';

/**
* Selector function name for fetching pre-update entity values.
*/
export const PRE_UPDATE_VALUE_SELECTOR = 'preValueSelect';

/**
* Prefix for field-level read checker function name
*/
export const FIELD_LEVEL_READ_CHECKER_PREFIX = 'readFieldCheck$';

/**
* Field-level access control evaluation selector function name
*/
export const FIELD_LEVEL_READ_CHECKER_SELECTOR = 'readFieldSelect';

/**
* Prefix for field-level update guard function name
*/
export const FIELD_LEVEL_UPDATE_GUARD_PREFIX = 'updateFieldCheck$';

/**
* Flag that indicates if the model has field-level access control
*/
export const HAS_FIELD_LEVEL_POLICY_FLAG = 'hasFieldLevelPolicy';
41 changes: 28 additions & 13 deletions packages/runtime/src/enhancements/policy/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -64,16 +64,19 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
throw prismaClientValidationError(this.prisma, 'where field is required in query argument');
}

const origArgs = args;
args = this.utils.clone(args);
if (!(await this.utils.injectForRead(this.prisma, this.model, args))) {
return null;
}

this.utils.injectReadCheckSelect(this.model, args);

if (this.shouldLogQuery) {
this.logger.info(`[policy] \`findUnique\` ${this.model}:\n${formatObject(args)}`);
}
const result = await this.modelClient.findUnique(args);
this.utils.postProcessForRead(result);
this.utils.postProcessForRead(result, this.model, origArgs);
return result;
}

Expand All @@ -85,58 +88,70 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
throw prismaClientValidationError(this.prisma, 'where field is required in query argument');
}

const origArgs = args;
args = this.utils.clone(args);
if (!(await this.utils.injectForRead(this.prisma, this.model, args))) {
throw this.utils.notFound(this.model);
}

this.utils.injectReadCheckSelect(this.model, args);

if (this.shouldLogQuery) {
this.logger.info(`[policy] \`findUniqueOrThrow\` ${this.model}:\n${formatObject(args)}`);
}
const result = await this.modelClient.findUniqueOrThrow(args);
this.utils.postProcessForRead(result);
this.utils.postProcessForRead(result, this.model, origArgs);
return result;
}

async findFirst(args: any) {
const origArgs = args;
args = args ? this.utils.clone(args) : {};
if (!(await this.utils.injectForRead(this.prisma, this.model, args))) {
return null;
}

this.utils.injectReadCheckSelect(this.model, args);

if (this.shouldLogQuery) {
this.logger.info(`[policy] \`findFirst\` ${this.model}:\n${formatObject(args)}`);
}
const result = await this.modelClient.findFirst(args);
this.utils.postProcessForRead(result);
this.utils.postProcessForRead(result, this.model, origArgs);
return result;
}

async findFirstOrThrow(args: any) {
const origArgs = args;
args = args ? this.utils.clone(args) : {};
if (!(await this.utils.injectForRead(this.prisma, this.model, args))) {
throw this.utils.notFound(this.model);
}

this.utils.injectReadCheckSelect(this.model, args);

if (this.shouldLogQuery) {
this.logger.info(`[policy] \`findFirstOrThrow\` ${this.model}:\n${formatObject(args)}`);
}
const result = await this.modelClient.findFirstOrThrow(args);
this.utils.postProcessForRead(result);
this.utils.postProcessForRead(result, this.model, origArgs);
return result;
}

async findMany(args: any) {
const origArgs = args;
args = args ? this.utils.clone(args) : {};
if (!(await this.utils.injectForRead(this.prisma, this.model, args))) {
return [];
}

this.utils.injectReadCheckSelect(this.model, args);

if (this.shouldLogQuery) {
this.logger.info(`[policy] \`findMany\` ${this.model}:\n${formatObject(args)}`);
}
const result = await this.modelClient.findMany(args);
this.utils.postProcessForRead(result);
this.utils.postProcessForRead(result, this.model, origArgs);
return result;
}

Expand Down Expand Up @@ -255,7 +270,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
if (backLinkField?.isRelationOwner) {
// the target side of relation owns the relation,
// check if it's updatable
await this.utils.checkPolicyForUnique(model, args.where, 'update', db);
await this.utils.checkPolicyForUnique(model, args.where, 'update', db, args);
}
}

Expand Down Expand Up @@ -300,7 +315,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr

// the target side of relation owns the relation,
// check if it's updatable
await this.utils.checkPolicyForUnique(model, args, 'update', db);
await this.utils.checkPolicyForUnique(model, args, 'update', db, args);
}
}
},
Expand Down Expand Up @@ -597,7 +612,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
const backLinkField = this.utils.getModelField(model, context.field.backLink);
if (backLinkField.isRelationOwner) {
// update happens on the related model, require updatable
await this.utils.checkPolicyForUnique(model, args, 'update', db);
await this.utils.checkPolicyForUnique(model, args, 'update', db, args);

// register post-update check
await _registerPostUpdateCheck(model, args);
Expand Down Expand Up @@ -638,7 +653,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
this.utils.tryReject(db, this.model, 'update');

// check pre-update guard
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db);
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args);

// handles the case where id fields are updated
const ids = this.utils.clone(existing);
Expand Down Expand Up @@ -721,7 +736,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
// update case

// check pre-update guard
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db);
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args);

// register post-update check
await _registerPostUpdateCheck(model, uniqueFilter);
Expand Down Expand Up @@ -789,7 +804,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
await this.utils.checkExistence(db, model, uniqueFilter, true);

// check delete guard
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'delete', db);
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'delete', db, args);
},

deleteMany: async (model, args, context) => {
Expand Down Expand Up @@ -942,7 +957,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
await this.utils.checkExistence(tx, this.model, args.where, true);

// inject delete guard
await this.utils.checkPolicyForUnique(this.model, args.where, 'delete', tx);
await this.utils.checkPolicyForUnique(this.model, args.where, 'delete', tx, args);

// proceed with the deletion
if (this.shouldLogQuery) {
Expand Down Expand Up @@ -1037,7 +1052,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
private async runPostWriteChecks(postWriteChecks: PostWriteCheckRecord[], db: Record<string, DbOperations>) {
await Promise.all(
postWriteChecks.map(async ({ model, operation, uniqueFilter, preValue }) =>
this.utils.checkPolicyForUnique(model, uniqueFilter, operation, db, preValue)
this.utils.checkPolicyForUnique(model, uniqueFilter, operation, db, undefined, preValue)
)
);
}
Expand Down
Loading