Skip to content

Commit 9a6f39b

Browse files
authored
feat: field-level access control (#638)
1 parent 9a35f88 commit 9a6f39b

28 files changed

+1702
-318
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,14 @@ The `zenstack` CLI transpiles the ZModel into a standard Prisma schema, which yo
6262
At runtime, transparent proxies are created around Prisma clients for intercepting queries and mutations to enforce access policies.
6363

6464
```ts
65-
import { withPolicy } from '@zenstackhq/runtime';
65+
import { enhance } from '@zenstackhq/runtime';
6666

6767
// a regular Prisma client
6868
const prisma = new PrismaClient();
6969

7070
async function getPosts(userId: string) {
7171
// create an enhanced Prisma client that has access control enabled
72-
const enhanced = withPolicy(prisma, { user: userId });
72+
const enhanced = enhance(prisma, { user: userId });
7373

7474
// only posts that're visible to the user will be returned
7575
return enhanced.post.findMany();
@@ -84,14 +84,14 @@ Server adapter packages help you wrap an access-control-enabled Prisma client in
8484
// pages/api/model/[...path].ts
8585

8686
import { requestHandler } from '@zenstackhq/next';
87-
import { withPolicy } from '@zenstackhq/runtime';
87+
import { enhance } from '@zenstackhq/runtime';
8888
import { getSessionUser } from '@lib/auth';
8989
import { prisma } from '@lib/db';
9090

9191
// Mount Prisma-style APIs: "/api/model/post/findMany", "/api/model/post/create", etc.
9292
// Can be configured to provide standard RESTful APIs (using JSON:API) instead.
9393
export default requestHandler({
94-
getPrisma: (req, res) => withPolicy(prisma, { user: getSessionUser(req, res) }),
94+
getPrisma: (req, res) => enhance(prisma, { user: getSessionUser(req, res) }),
9595
});
9696
```
9797

packages/runtime/src/constants.ts

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,28 @@ export const PRISMA_PROXY_ENHANCER = '$__zenstack_enhancer';
7272
* Minimum Prisma version supported
7373
*/
7474
export const PRISMA_MINIMUM_VERSION = '4.8.0';
75+
76+
/**
77+
* Selector function name for fetching pre-update entity values.
78+
*/
79+
export const PRE_UPDATE_VALUE_SELECTOR = 'preValueSelect';
80+
81+
/**
82+
* Prefix for field-level read checker function name
83+
*/
84+
export const FIELD_LEVEL_READ_CHECKER_PREFIX = 'readFieldCheck$';
85+
86+
/**
87+
* Field-level access control evaluation selector function name
88+
*/
89+
export const FIELD_LEVEL_READ_CHECKER_SELECTOR = 'readFieldSelect';
90+
91+
/**
92+
* Prefix for field-level update guard function name
93+
*/
94+
export const FIELD_LEVEL_UPDATE_GUARD_PREFIX = 'updateFieldCheck$';
95+
96+
/**
97+
* Flag that indicates if the model has field-level access control
98+
*/
99+
export const HAS_FIELD_LEVEL_POLICY_FLAG = 'hasFieldLevelPolicy';

packages/runtime/src/enhancements/policy/handler.ts

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -64,16 +64,19 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
6464
throw prismaClientValidationError(this.prisma, 'where field is required in query argument');
6565
}
6666

67+
const origArgs = args;
6768
args = this.utils.clone(args);
6869
if (!(await this.utils.injectForRead(this.prisma, this.model, args))) {
6970
return null;
7071
}
7172

73+
this.utils.injectReadCheckSelect(this.model, args);
74+
7275
if (this.shouldLogQuery) {
7376
this.logger.info(`[policy] \`findUnique\` ${this.model}:\n${formatObject(args)}`);
7477
}
7578
const result = await this.modelClient.findUnique(args);
76-
this.utils.postProcessForRead(result);
79+
this.utils.postProcessForRead(result, this.model, origArgs);
7780
return result;
7881
}
7982

@@ -85,58 +88,70 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
8588
throw prismaClientValidationError(this.prisma, 'where field is required in query argument');
8689
}
8790

91+
const origArgs = args;
8892
args = this.utils.clone(args);
8993
if (!(await this.utils.injectForRead(this.prisma, this.model, args))) {
9094
throw this.utils.notFound(this.model);
9195
}
9296

97+
this.utils.injectReadCheckSelect(this.model, args);
98+
9399
if (this.shouldLogQuery) {
94100
this.logger.info(`[policy] \`findUniqueOrThrow\` ${this.model}:\n${formatObject(args)}`);
95101
}
96102
const result = await this.modelClient.findUniqueOrThrow(args);
97-
this.utils.postProcessForRead(result);
103+
this.utils.postProcessForRead(result, this.model, origArgs);
98104
return result;
99105
}
100106

101107
async findFirst(args: any) {
108+
const origArgs = args;
102109
args = args ? this.utils.clone(args) : {};
103110
if (!(await this.utils.injectForRead(this.prisma, this.model, args))) {
104111
return null;
105112
}
106113

114+
this.utils.injectReadCheckSelect(this.model, args);
115+
107116
if (this.shouldLogQuery) {
108117
this.logger.info(`[policy] \`findFirst\` ${this.model}:\n${formatObject(args)}`);
109118
}
110119
const result = await this.modelClient.findFirst(args);
111-
this.utils.postProcessForRead(result);
120+
this.utils.postProcessForRead(result, this.model, origArgs);
112121
return result;
113122
}
114123

115124
async findFirstOrThrow(args: any) {
125+
const origArgs = args;
116126
args = args ? this.utils.clone(args) : {};
117127
if (!(await this.utils.injectForRead(this.prisma, this.model, args))) {
118128
throw this.utils.notFound(this.model);
119129
}
120130

131+
this.utils.injectReadCheckSelect(this.model, args);
132+
121133
if (this.shouldLogQuery) {
122134
this.logger.info(`[policy] \`findFirstOrThrow\` ${this.model}:\n${formatObject(args)}`);
123135
}
124136
const result = await this.modelClient.findFirstOrThrow(args);
125-
this.utils.postProcessForRead(result);
137+
this.utils.postProcessForRead(result, this.model, origArgs);
126138
return result;
127139
}
128140

129141
async findMany(args: any) {
142+
const origArgs = args;
130143
args = args ? this.utils.clone(args) : {};
131144
if (!(await this.utils.injectForRead(this.prisma, this.model, args))) {
132145
return [];
133146
}
134147

148+
this.utils.injectReadCheckSelect(this.model, args);
149+
135150
if (this.shouldLogQuery) {
136151
this.logger.info(`[policy] \`findMany\` ${this.model}:\n${formatObject(args)}`);
137152
}
138153
const result = await this.modelClient.findMany(args);
139-
this.utils.postProcessForRead(result);
154+
this.utils.postProcessForRead(result, this.model, origArgs);
140155
return result;
141156
}
142157

@@ -255,7 +270,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
255270
if (backLinkField?.isRelationOwner) {
256271
// the target side of relation owns the relation,
257272
// check if it's updatable
258-
await this.utils.checkPolicyForUnique(model, args.where, 'update', db);
273+
await this.utils.checkPolicyForUnique(model, args.where, 'update', db, args);
259274
}
260275
}
261276

@@ -300,7 +315,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
300315

301316
// the target side of relation owns the relation,
302317
// check if it's updatable
303-
await this.utils.checkPolicyForUnique(model, args, 'update', db);
318+
await this.utils.checkPolicyForUnique(model, args, 'update', db, args);
304319
}
305320
}
306321
},
@@ -597,7 +612,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
597612
const backLinkField = this.utils.getModelField(model, context.field.backLink);
598613
if (backLinkField.isRelationOwner) {
599614
// update happens on the related model, require updatable
600-
await this.utils.checkPolicyForUnique(model, args, 'update', db);
615+
await this.utils.checkPolicyForUnique(model, args, 'update', db, args);
601616

602617
// register post-update check
603618
await _registerPostUpdateCheck(model, args);
@@ -638,7 +653,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
638653
this.utils.tryReject(db, this.model, 'update');
639654

640655
// check pre-update guard
641-
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db);
656+
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args);
642657

643658
// handles the case where id fields are updated
644659
const ids = this.utils.clone(existing);
@@ -721,7 +736,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
721736
// update case
722737

723738
// check pre-update guard
724-
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db);
739+
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'update', db, args);
725740

726741
// register post-update check
727742
await _registerPostUpdateCheck(model, uniqueFilter);
@@ -789,7 +804,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
789804
await this.utils.checkExistence(db, model, uniqueFilter, true);
790805

791806
// check delete guard
792-
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'delete', db);
807+
await this.utils.checkPolicyForUnique(model, uniqueFilter, 'delete', db, args);
793808
},
794809

795810
deleteMany: async (model, args, context) => {
@@ -942,7 +957,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
942957
await this.utils.checkExistence(tx, this.model, args.where, true);
943958

944959
// inject delete guard
945-
await this.utils.checkPolicyForUnique(this.model, args.where, 'delete', tx);
960+
await this.utils.checkPolicyForUnique(this.model, args.where, 'delete', tx, args);
946961

947962
// proceed with the deletion
948963
if (this.shouldLogQuery) {
@@ -1037,7 +1052,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
10371052
private async runPostWriteChecks(postWriteChecks: PostWriteCheckRecord[], db: Record<string, DbOperations>) {
10381053
await Promise.all(
10391054
postWriteChecks.map(async ({ model, operation, uniqueFilter, preValue }) =>
1040-
this.utils.checkPolicyForUnique(model, uniqueFilter, operation, db, preValue)
1055+
this.utils.checkPolicyForUnique(model, uniqueFilter, operation, db, undefined, preValue)
10411056
)
10421057
);
10431058
}

0 commit comments

Comments
 (0)