Skip to content

Commit ff1e8a5

Browse files
authored
feat: support using collection predicate expression with auth() (#831)
1 parent 93dc7df commit ff1e8a5

File tree

6 files changed

+175
-15
lines changed

6 files changed

+175
-15
lines changed

packages/runtime/src/enhancements/policy/index.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import { hasAllFields } from '../../validation';
1010
import { makeProxy } from '../proxy';
1111
import type { CommonEnhancementOptions, PolicyDef, ZodSchemas } from '../types';
1212
import { PolicyProxyHandler } from './handler';
13+
import { Logger } from './logger';
1314

1415
/**
1516
* Context for evaluating access policies
@@ -72,7 +73,8 @@ export function withPolicy<DbClient extends object>(
7273
const _zodSchemas = options?.zodSchemas ?? getDefaultZodSchemas(options?.loadPath);
7374

7475
// validate user context
75-
if (context?.user && _modelMeta.authModel) {
76+
const userContext = context?.user;
77+
if (userContext && _modelMeta.authModel) {
7678
const idFields = getIdFields(_modelMeta, _modelMeta.authModel);
7779
if (
7880
!hasAllFields(
@@ -84,6 +86,16 @@ export function withPolicy<DbClient extends object>(
8486
`Invalid user context: must have valid ID field ${idFields.map((f) => `"${f.name}"`).join(', ')}`
8587
);
8688
}
89+
90+
// validate user context for fields used in policy expressions
91+
const authSelector = _policy.authSelector;
92+
if (authSelector) {
93+
Object.keys(authSelector).forEach((f) => {
94+
if (!(f in userContext)) {
95+
console.warn(`User context does not have field "${f}" used in policy rules`);
96+
}
97+
});
98+
}
8799
}
88100

89101
return makeProxy(

packages/runtime/src/enhancements/types.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,12 @@ export type PolicyDef = {
5656
[HAS_FIELD_LEVEL_POLICY_FLAG]?: boolean;
5757
}
5858
>;
59+
60+
// tracks which models have data validation rules
5961
validation: Record<string, { hasValidation: boolean }>;
62+
63+
// a { select: ... } object for fetching `auth()` fields needed for policy evaluation
64+
authSelector?: object;
6065
};
6166

6267
/**

packages/schema/src/plugins/access-policy/expression-writer.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,12 @@ export class ExpressionWriter {
226226
// check if the operand should be compiled to a relation query
227227
// or a plain expression
228228
const compileToRelationQuery =
229-
(this.isPostGuard && this.isFutureMemberAccess(expr.left)) ||
230-
(!this.isPostGuard && !this.isFutureMemberAccess(expr.left));
229+
// expression rooted to `auth()` is always compiled to plain expression
230+
!this.isAuthOrAuthMemberAccess(expr.left) &&
231+
// `future()` in post-update context
232+
((this.isPostGuard && this.isFutureMemberAccess(expr.left)) ||
233+
// non-`future()` in pre-update context
234+
(!this.isPostGuard && !this.isFutureMemberAccess(expr.left)));
231235

232236
if (compileToRelationQuery) {
233237
this.block(() => {

packages/schema/src/plugins/access-policy/policy-guard-generator.ts

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ export default class PolicyGenerator {
8181
namedImports: [
8282
{ name: 'type QueryContext' },
8383
{ name: 'type DbOperations' },
84-
{ name: 'hasAllFields' },
8584
{ name: 'allFieldsEqual' },
8685
{ name: 'type PolicyDef' },
8786
],
@@ -104,6 +103,8 @@ export default class PolicyGenerator {
104103
policyMap[model.name] = await this.generateQueryGuardForModel(model, sf);
105104
}
106105

106+
const authSelector = this.generateAuthSelector(models);
107+
107108
sf.addVariableStatement({
108109
declarationKind: VariableDeclarationKind.Const,
109110
declarations: [
@@ -140,6 +141,11 @@ export default class PolicyGenerator {
140141
writer.writeLine(',');
141142
}
142143
});
144+
145+
if (authSelector) {
146+
writer.writeLine(',');
147+
writer.write(`authSelector: ${JSON.stringify(authSelector)}`);
148+
}
143149
});
144150
},
145151
},
@@ -165,6 +171,43 @@ export default class PolicyGenerator {
165171
}
166172
}
167173

174+
// Generates a { select: ... } object to select `auth()` fields used in policy rules
175+
private generateAuthSelector(models: DataModel[]) {
176+
const authRules: Expression[] = [];
177+
178+
models.forEach((model) => {
179+
// model-level rules
180+
const modelPolicyAttrs = model.attributes.filter((attr) =>
181+
['@@allow', '@@deny'].includes(attr.decl.$refText)
182+
);
183+
184+
// field-level rules
185+
const fieldPolicyAttrs = model.fields
186+
.flatMap((f) => f.attributes)
187+
.filter((attr) => ['@allow', '@deny'].includes(attr.decl.$refText));
188+
189+
// all rule expression
190+
const allExpressions = [...modelPolicyAttrs, ...fieldPolicyAttrs]
191+
.filter((attr) => attr.args.length > 1)
192+
.map((attr) => attr.args[1].value);
193+
194+
// collect `auth()` member access
195+
allExpressions.forEach((rule) => {
196+
streamAst(rule).forEach((node) => {
197+
if (isMemberAccessExpr(node) && isAuthInvocation(node.operand)) {
198+
authRules.push(node);
199+
}
200+
});
201+
});
202+
});
203+
204+
if (authRules.length > 0) {
205+
return this.generateSelectForRules(authRules, true);
206+
} else {
207+
return undefined;
208+
}
209+
}
210+
168211
private isEnumReferenced(model: Model, decl: Enum): unknown {
169212
return streamAllContents(model).some((node) => {
170213
if (isDataModelField(node) && node.type.reference?.ref === decl) {
@@ -293,7 +336,7 @@ export default class PolicyGenerator {
293336
result[kind] = guardFunc.getName()!;
294337

295338
if (kind === 'postUpdate') {
296-
const preValueSelect = this.generateSelectForRules(allows, denies);
339+
const preValueSelect = this.generateSelectForRules([...allows, ...denies]);
297340
if (preValueSelect) {
298341
result[PRE_UPDATE_VALUE_SELECTOR] = preValueSelect;
299342
}
@@ -340,7 +383,7 @@ export default class PolicyGenerator {
340383

341384
if (allFieldsAllows.length > 0 || allFieldsDenies.length > 0) {
342385
result[HAS_FIELD_LEVEL_POLICY_FLAG] = true;
343-
const readFieldCheckSelect = this.generateSelectForRules(allFieldsAllows, allFieldsDenies);
386+
const readFieldCheckSelect = this.generateSelectForRules([...allFieldsAllows, ...allFieldsDenies]);
344387
if (readFieldCheckSelect) {
345388
result[FIELD_LEVEL_READ_CHECKER_SELECTOR] = readFieldCheckSelect;
346389
}
@@ -477,7 +520,7 @@ export default class PolicyGenerator {
477520

478521
// generates a "select" object that contains (recursively) fields referenced by the
479522
// given policy rules
480-
private generateSelectForRules(allows: Expression[], denies: Expression[]): object {
523+
private generateSelectForRules(rules: Expression[], forAuthContext = false): object {
481524
// eslint-disable-next-line @typescript-eslint/no-explicit-any
482525
const result: any = {};
483526
const addPath = (path: string[]) => {
@@ -504,6 +547,10 @@ export default class PolicyGenerator {
504547
return [target.name];
505548
}
506549
} else if (isMemberAccessExpr(node)) {
550+
if (forAuthContext && isAuthInvocation(node.operand)) {
551+
return [node.member.$refText];
552+
}
553+
507554
if (isFutureExpr(node.operand)) {
508555
// future().field is not subject to pre-update select
509556
return undefined;
@@ -562,7 +609,7 @@ export default class PolicyGenerator {
562609
}
563610
};
564611

565-
for (const rule of [...allows, ...denies]) {
612+
for (const rule of rules) {
566613
const paths = collectReferencePaths(rule);
567614
paths.forEach((p) => addPath(p));
568615
}
@@ -780,11 +827,7 @@ export default class PolicyGenerator {
780827
}
781828

782829
// normalize user to null to avoid accidentally use undefined in filter
783-
statements.push(
784-
`const user = hasAllFields(context.user, [${userIdFields
785-
.map((f) => "'" + f.name + "'")
786-
.join(', ')}]) ? context.user as any : null;`
787-
);
830+
statements.push(`const user: any = context.user ?? null;`);
788831
}
789832
}
790833
}

packages/schema/src/utils/ast-utils.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,6 @@ export function getAllDeclarationsFromImports(documents: LangiumDocuments, model
153153
return model.declarations.concat(...imports.map((imp) => imp.declarations));
154154
}
155155

156-
export function isCollectionPredicate(expr: Expression): expr is BinaryExpr {
157-
return isBinaryExpr(expr) && ['?', '!', '^'].includes(expr.operator);
156+
export function isCollectionPredicate(node: AstNode): node is BinaryExpr {
157+
return isBinaryExpr(node) && ['?', '!', '^'].includes(node.operator);
158158
}

tests/integration/tests/enhancements/with-policy/auth.test.ts

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,4 +212,100 @@ describe('With Policy: auth() test', () => {
212212
const adminDb = withPolicy({ id: 'user1', role: 'ADMIN' });
213213
await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy();
214214
});
215+
216+
it('collection predicate', async () => {
217+
const { enhance, prisma } = await loadSchema(
218+
`
219+
model User {
220+
id String @id @default(uuid())
221+
posts Post[]
222+
223+
@@allow('all', true)
224+
}
225+
226+
model Post {
227+
id String @id @default(uuid())
228+
title String
229+
published Boolean @default(false)
230+
author User @relation(fields: [authorId], references: [id])
231+
authorId String
232+
comments Comment[]
233+
234+
@@allow('read', true)
235+
@@allow('create', auth().posts?[published && comments![published]])
236+
}
237+
238+
model Comment {
239+
id String @id @default(uuid())
240+
published Boolean @default(false)
241+
post Post @relation(fields: [postId], references: [id])
242+
postId String
243+
244+
@@allow('all', true)
245+
}
246+
`
247+
);
248+
249+
const user = await prisma.user.create({ data: {} });
250+
251+
const createPayload = {
252+
data: { title: 'Post 1', author: { connect: { id: user.id } } },
253+
};
254+
255+
// no post
256+
await expect(enhance({ id: '1' }).post.create(createPayload)).toBeRejectedByPolicy();
257+
258+
// post not published
259+
await expect(
260+
enhance({ id: '1', posts: [{ id: '1', published: false }] }).post.create(createPayload)
261+
).toBeRejectedByPolicy();
262+
263+
// no comments
264+
await expect(
265+
enhance({ id: '1', posts: [{ id: '1', published: true }] }).post.create(createPayload)
266+
).toBeRejectedByPolicy();
267+
268+
// not all comments published
269+
await expect(
270+
enhance({
271+
id: '1',
272+
posts: [
273+
{
274+
id: '1',
275+
published: true,
276+
comments: [
277+
{ id: '1', published: true },
278+
{ id: '2', published: false },
279+
],
280+
},
281+
],
282+
}).post.create(createPayload)
283+
).toBeRejectedByPolicy();
284+
285+
// comments published but parent post is not
286+
await expect(
287+
enhance({
288+
id: '1',
289+
posts: [
290+
{ id: '1', published: false, comments: [{ id: '1', published: true }] },
291+
{ id: '2', published: true },
292+
],
293+
}).post.create(createPayload)
294+
).toBeRejectedByPolicy();
295+
296+
await expect(
297+
enhance({
298+
id: '1',
299+
posts: [
300+
{ id: '1', published: true, comments: [{ id: '1', published: true }] },
301+
{ id: '2', published: false },
302+
],
303+
}).post.create(createPayload)
304+
).toResolveTruthy();
305+
306+
// no comments ("every" evaluates to tru in this case)
307+
await expect(
308+
enhance({ id: '1', posts: [{ id: '1', published: true, comments: [] }] }).post.create(createPayload)
309+
).toResolveTruthy();
310+
});
215311
});

0 commit comments

Comments
 (0)