Skip to content

feat: field-level policy override #889

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion packages/runtime/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,20 @@ export const FIELD_LEVEL_READ_CHECKER_PREFIX = 'readFieldCheck$';
*/
export const FIELD_LEVEL_READ_CHECKER_SELECTOR = 'readFieldSelect';

/**
* Prefix for field-level override read guard function name
*/
export const FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX = 'readFieldGuardOverride$';

/**
* Prefix for field-level update guard function name
*/
export const FIELD_LEVEL_UPDATE_GUARD_PREFIX = 'updateFieldCheck$';
export const FIELD_LEVEL_UPDATE_GUARD_PREFIX = 'updateFieldGuard$';

/**
* Prefix for field-level override update guard function name
*/
export const FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX = 'updateFieldGuardOverride$';

/**
* Flag that indicates if the model has field-level access control
Expand Down
190 changes: 161 additions & 29 deletions packages/runtime/src/enhancements/policy/policy-utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import { ZodError } from 'zod';
import { fromZodError } from 'zod-validation-error';
import {
CrudFailureReason,
FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX,
FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX,
FIELD_LEVEL_READ_CHECKER_PREFIX,
FIELD_LEVEL_READ_CHECKER_SELECTOR,
FIELD_LEVEL_UPDATE_GUARD_PREFIX,
Expand Down Expand Up @@ -236,12 +238,7 @@ export class PolicyUtil {
* @returns true if operation is unconditionally allowed, false if unconditionally denied,
* otherwise returns a guard object
*/
getAuthGuard(
db: Record<string, DbOperations>,
model: string,
operation: PolicyOperationKind,
preValue?: any
): object {
getAuthGuard(db: Record<string, DbOperations>, model: string, operation: PolicyOperationKind, preValue?: any) {
const guard = this.policy.guard[lowerCaseFirst(model)];
if (!guard) {
throw this.unknownError(`unable to load policy guard for ${model}`);
Expand All @@ -260,23 +257,61 @@ export class PolicyUtil {
}

/**
* Get field-level auth guard
* Get field-level read auth guard that overrides the model-level
*/
getFieldUpdateAuthGuard(db: Record<string, DbOperations>, model: string, field: string): object {
const guard = this.policy.guard[lowerCaseFirst(model)];
if (!guard) {
throw this.unknownError(`unable to load policy guard for ${model}`);
getFieldOverrideReadAuthGuard(db: Record<string, DbOperations>, model: string, field: string) {
const guard = this.requireGuard(model);

const provider = guard[`${FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX}${field}`];
if (provider === undefined) {
// field access is denied by default in override mode
return this.makeFalse();
}

const provider = guard[`${FIELD_LEVEL_UPDATE_GUARD_PREFIX}${field}`];
if (typeof provider === 'boolean') {
return this.reduce(provider);
}

if (!provider) {
const r = provider({ user: this.user }, db);
return this.reduce(r);
}

/**
* Get field-level update auth guard
*/
getFieldUpdateAuthGuard(db: Record<string, DbOperations>, model: string, field: string) {
const guard = this.requireGuard(model);

const provider = guard[`${FIELD_LEVEL_UPDATE_GUARD_PREFIX}${field}`];
if (provider === undefined) {
// field access is allowed by default
return this.makeTrue();
}

if (typeof provider === 'boolean') {
return this.reduce(provider);
}

const r = provider({ user: this.user }, db);
return this.reduce(r);
}

/**
* Get field-level update auth guard that overrides the model-level
*/
getFieldOverrideUpdateAuthGuard(db: Record<string, DbOperations>, model: string, field: string) {
const guard = this.requireGuard(model);

const provider = guard[`${FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX}${field}`];
if (provider === undefined) {
// field access is denied by default in override mode
return this.makeFalse();
}

if (typeof provider === 'boolean') {
return this.reduce(provider);
}

const r = provider({ user: this.user }, db);
return this.reduce(r);
}
Expand Down Expand Up @@ -322,10 +357,6 @@ export class PolicyUtil {
*/
injectAuthGuard(db: Record<string, DbOperations>, args: any, model: string, operation: PolicyOperationKind) {
let guard = this.getAuthGuard(db, model, operation);
if (this.isFalse(guard)) {
args.where = this.makeFalse();
return false;
}

if (operation === 'update' && args) {
// merge field-level policy guards
Expand All @@ -334,12 +365,32 @@ export class PolicyUtil {
// rejected
args.where = this.makeFalse();
return false;
} else if (fieldUpdateGuard.guard) {
// merge
guard = this.and(guard, fieldUpdateGuard.guard);
} else {
if (fieldUpdateGuard.guard) {
// merge field-level guard
guard = this.and(guard, fieldUpdateGuard.guard);
}

if (fieldUpdateGuard.overrideGuard) {
// merge field-level override guard on the top level
guard = this.or(guard, fieldUpdateGuard.overrideGuard);
}
}
}

if (operation === 'read') {
// merge field-level read override guards
const fieldReadOverrideGuard = this.getFieldReadGuards(db, model, args);
if (fieldReadOverrideGuard) {
guard = this.or(guard, fieldReadOverrideGuard);
}
}

if (this.isFalse(guard)) {
args.where = this.makeFalse();
return false;
}

if (args.where) {
// inject into relation fields:
// to-many: some/none/every
Expand Down Expand Up @@ -441,7 +492,8 @@ export class PolicyUtil {
* Injects auth guard for read operations.
*/
injectForRead(db: Record<string, DbOperations>, model: string, args: any) {
const injected: any = {};
// make select and include visible to the injection
const injected: any = { select: args.select, include: args.include };
if (!this.injectAuthGuard(db, injected, model, 'read')) {
return false;
}
Expand Down Expand Up @@ -701,9 +753,16 @@ export class PolicyUtil {
}"`,
CrudFailureReason.ACCESS_POLICY_VIOLATION
);
} else if (fieldUpdateGuard.guard) {
// merge
guard = this.and(guard, fieldUpdateGuard.guard);
} else {
if (fieldUpdateGuard.guard) {
// merge field-level guard
guard = this.and(guard, fieldUpdateGuard.guard);
}

if (fieldUpdateGuard.overrideGuard) {
// merge field-level override guard
guard = this.or(guard, fieldUpdateGuard.overrideGuard);
}
}
}

Expand Down Expand Up @@ -761,8 +820,33 @@ export class PolicyUtil {
}
}

private getFieldReadGuards(db: Record<string, DbOperations>, model: string, args: { select?: any; include?: any }) {
const allFields = Object.values(getFields(this.modelMeta, model));

// all scalar fields by default
let fields = allFields.filter((f) => !f.isDataModel);

if (args.select) {
// explicitly selected fields
fields = allFields.filter((f) => args.select?.[f.name] === true);
} else if (args.include) {
// included relations
fields.push(...allFields.filter((f) => !fields.includes(f) && args.include[f.name]));
}

if (fields.length === 0) {
// this can happen if only selecting pseudo fields like "_count"
return undefined;
}

const allFieldGuards = fields.map((field) => this.getFieldOverrideReadAuthGuard(db, model, field.name));
return this.and(...allFieldGuards);
}

private getFieldUpdateGuards(db: Record<string, DbOperations>, model: string, args: any) {
const allFieldGuards = [];
const allOverrideFieldGuards = [];

for (const [k, v] of Object.entries<any>(args.data ?? args)) {
if (typeof v === 'undefined') {
continue;
Expand All @@ -778,20 +862,41 @@ export class PolicyUtil {
for (const fk of foreignKeys) {
const fieldGuard = this.getFieldUpdateAuthGuard(db, model, fk);
if (this.isFalse(fieldGuard)) {
return { guard: allFieldGuards, rejectedByField: fk };
return { guard: fieldGuard, rejectedByField: fk };
}

// add field guard
allFieldGuards.push(fieldGuard);

// add field override guard
const overrideFieldGuard = this.getFieldOverrideUpdateAuthGuard(db, model, fk);
allOverrideFieldGuards.push(overrideFieldGuard);
}
}
} else {
const fieldGuard = this.getFieldUpdateAuthGuard(db, model, k);
if (this.isFalse(fieldGuard)) {
return { guard: allFieldGuards, rejectedByField: k };
return { guard: fieldGuard, rejectedByField: k };
}

// add field guard
allFieldGuards.push(fieldGuard);

// add field override guard
const overrideFieldGuard = this.getFieldOverrideUpdateAuthGuard(db, model, k);
allOverrideFieldGuards.push(overrideFieldGuard);
}
}
return { guard: this.and(...allFieldGuards), rejectedByField: undefined };

const allFieldsCombined = this.and(...allFieldGuards);
const allOverrideFieldsCombined =
allOverrideFieldGuards.length !== 0 ? this.and(...allOverrideFieldGuards) : undefined;

return {
guard: allFieldsCombined,
overrideGuard: allOverrideFieldsCombined,
rejectedByField: undefined,
};
}

/**
Expand Down Expand Up @@ -841,7 +946,13 @@ export class PolicyUtil {
): Promise<{ result: unknown; error?: Error }> {
uniqueFilter = this.clone(uniqueFilter);
this.flattenGeneratedUniqueField(model, uniqueFilter);
const readArgs = { select: selectInclude.select, include: selectInclude.include, where: uniqueFilter };

// make sure only select and include are picked
const selectIncludeClean = this.pick(selectInclude, 'select', 'include');
const readArgs = {
...this.clone(selectIncludeClean),
where: uniqueFilter,
};

const error = this.deniedByPolicy(
model,
Expand All @@ -866,7 +977,7 @@ export class PolicyUtil {
return { error, result: undefined };
}

this.postProcessForRead(result, model, selectInclude);
this.postProcessForRead(result, model, selectIncludeClean);
return { result, error: undefined };
}

Expand Down Expand Up @@ -1165,6 +1276,19 @@ export class PolicyUtil {
return value ? deepcopy(value) : {};
}

/**
* Picks properties from an object.
*/
pick<T>(value: T, ...props: (keyof T)[]): Pick<T, (typeof props)[number]> {
const v: any = value;
return props.reduce(function (result, prop) {
if (prop in v) {
result[prop] = v[prop];
}
return result;
}, {} as any);
}

/**
* Gets "id" fields for a given model.
*/
Expand Down Expand Up @@ -1218,5 +1342,13 @@ export class PolicyUtil {
}
}

private requireGuard(model: string) {
const guard = this.policy.guard[lowerCaseFirst(model)];
if (!guard) {
throw this.unknownError(`unable to load policy guard for ${model}`);
}
return guard;
}

//#endregion
}
9 changes: 8 additions & 1 deletion packages/runtime/src/enhancements/types.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { z } from 'zod';
import {
FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX,
FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX,
FIELD_LEVEL_READ_CHECKER_PREFIX,
FIELD_LEVEL_READ_CHECKER_SELECTOR,
FIELD_LEVEL_UPDATE_GUARD_PREFIX,
Expand Down Expand Up @@ -47,7 +49,12 @@ export type PolicyDef = {
Partial<Record<`${PolicyOperationKind}_input`, InputCheckFunc | boolean>> &
// field-level read checker functions or update guard functions
Record<`${typeof FIELD_LEVEL_READ_CHECKER_PREFIX}${string}`, ReadFieldCheckFunc> &
Record<`${typeof FIELD_LEVEL_UPDATE_GUARD_PREFIX}${string}`, PolicyFunc> & {
Record<
| `${typeof FIELD_LEVEL_OVERRIDE_READ_GUARD_PREFIX}${string}`
| `${typeof FIELD_LEVEL_UPDATE_GUARD_PREFIX}${string}`
| `${typeof FIELD_LEVEL_OVERRIDE_UPDATE_GUARD_PREFIX}${string}`,
PolicyFunc
> & {
// pre-update value selector
[PRE_UPDATE_VALUE_SELECTOR]?: object;
// field-level read checker selector
Expand Down
Loading