Skip to content

Commit ff86ce0

Browse files
committed
fix(delegate): delegate model's guards are not properly including concrete models
fixes #1930
1 parent f609c86 commit ff86ce0

File tree

8 files changed

+319
-59
lines changed

8 files changed

+319
-59
lines changed

packages/schema/src/plugins/enhancer/enhance/index.ts

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ import {
2424
isArrayExpr,
2525
isDataModel,
2626
isGeneratorDecl,
27-
isReferenceExpr,
2827
isTypeDef,
2928
type Model,
3029
} from '@zenstackhq/sdk/ast';
@@ -45,6 +44,7 @@ import {
4544
} from 'ts-morph';
4645
import { upperCaseFirst } from 'upper-case-first';
4746
import { name } from '..';
47+
import { getConcreteModels, getDiscriminatorField } from '../../../utils/ast-utils';
4848
import { execPackage } from '../../../utils/exec-utils';
4949
import { CorePlugins, getPluginCustomOutputFolder } from '../../plugin-utils';
5050
import { trackPrismaSchemaError } from '../../prisma';
@@ -407,9 +407,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
407407
this.model.declarations
408408
.filter((d): d is DataModel => isDelegateModel(d))
409409
.forEach((dm) => {
410-
const concreteModels = this.model.declarations.filter(
411-
(d): d is DataModel => isDataModel(d) && d.superTypes.some((s) => s.ref === dm)
412-
);
410+
const concreteModels = getConcreteModels(dm);
413411
if (concreteModels.length > 0) {
414412
delegateInfo.push([dm, concreteModels]);
415413
}
@@ -579,7 +577,7 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
579577
const typeName = typeAlias.getName();
580578
const payloadRecord = delegateInfo.find(([delegate]) => `$${delegate.name}Payload` === typeName);
581579
if (payloadRecord) {
582-
const discriminatorDecl = this.getDiscriminatorField(payloadRecord[0]);
580+
const discriminatorDecl = getDiscriminatorField(payloadRecord[0]);
583581
if (discriminatorDecl) {
584582
source = `${payloadRecord[1]
585583
.map(
@@ -826,15 +824,6 @@ export function enhance(prisma: any, context?: EnhancementContext<${authTypePara
826824
.filter((n) => n.getName().startsWith(DELEGATE_AUX_RELATION_PREFIX));
827825
}
828826

829-
private getDiscriminatorField(delegate: DataModel) {
830-
const delegateAttr = getAttribute(delegate, '@@delegate');
831-
if (!delegateAttr) {
832-
return undefined;
833-
}
834-
const arg = delegateAttr.args[0]?.value;
835-
return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined;
836-
}
837-
838827
private saveSourceFile(sf: SourceFile) {
839828
if (this.options.preserveTsFiles) {
840829
saveSourceFile(sf);

packages/schema/src/plugins/enhancer/policy/expression-writer.ts

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -839,16 +839,18 @@ export class ExpressionWriter {
839839
operation = this.options.operationContext;
840840
}
841841

842-
this.block(() => {
843-
if (operation === 'postUpdate') {
844-
// 'postUpdate' policies are not delegated to relations, just use constant `false` here
845-
// e.g.:
846-
// @@allow('all', check(author)) should not delegate "postUpdate" to author
847-
this.writer.write(`${fieldRef.target.$refText}: ${FALSE}`);
848-
} else {
849-
const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation);
850-
this.writer.write(`${fieldRef.target.$refText}: ${targetGuardFunc}(context, db)`);
851-
}
852-
});
842+
this.block(() =>
843+
this.writeFieldCondition(fieldRef, () => {
844+
if (operation === 'postUpdate') {
845+
// 'postUpdate' policies are not delegated to relations, just use constant `false` here
846+
// e.g.:
847+
// @@allow('all', check(author)) should not delegate "postUpdate" to author
848+
this.writer.write(FALSE);
849+
} else {
850+
const targetGuardFunc = getQueryGuardFunctionName(targetModel, undefined, false, operation);
851+
this.writer.write(`${targetGuardFunc}(context, db)`);
852+
}
853+
})
854+
);
853855
}
854856
}

packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts

Lines changed: 36 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import {
2525
hasAttribute,
2626
hasValidationAttributes,
2727
isAuthInvocation,
28+
isDelegateModel,
2829
isForeignKeyField,
2930
saveSourceFile,
3031
} from '@zenstackhq/sdk';
@@ -454,36 +455,44 @@ export class PolicyGenerator {
454455
writer: CodeBlockWriter,
455456
sourceFile: SourceFile
456457
) {
457-
if (kind === 'update' && allows.length === 0) {
458-
// no allow rule for 'update', policy is constant based on if there's
459-
// post-update counterpart
460-
let func: FunctionDeclaration;
461-
if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) {
462-
func = generateConstantQueryGuardFunction(sourceFile, model, kind, false);
463-
} else {
464-
func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
458+
const isDelegate = isDelegateModel(model);
459+
460+
if (!isDelegate) {
461+
// handle cases where a constant function can be used
462+
// note that this doesn't apply to delegate models because
463+
// all concrete models inheriting it need to be considered
464+
465+
if (kind === 'update' && allows.length === 0) {
466+
// no allow rule for 'update', policy is constant based on if there's
467+
// post-update counterpart
468+
let func: FunctionDeclaration;
469+
if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) {
470+
func = generateConstantQueryGuardFunction(sourceFile, model, kind, false);
471+
} else {
472+
func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
473+
}
474+
writer.write(`guard: ${func.getName()!},`);
475+
return;
465476
}
466-
writer.write(`guard: ${func.getName()!},`);
467-
return;
468-
}
469477

470-
if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) {
471-
// no 'postUpdate' rule, always allow
472-
const func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
473-
writer.write(`guard: ${func.getName()},`);
474-
return;
475-
}
478+
if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) {
479+
// no 'postUpdate' rule, always allow
480+
const func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
481+
writer.write(`guard: ${func.getName()},`);
482+
return;
483+
}
476484

477-
if (kind in policies && typeof policies[kind as keyof typeof policies] === 'boolean') {
478-
// constant policy
479-
const func = generateConstantQueryGuardFunction(
480-
sourceFile,
481-
model,
482-
kind,
483-
policies[kind as keyof typeof policies] as boolean
484-
);
485-
writer.write(`guard: ${func.getName()!},`);
486-
return;
485+
if (kind in policies && typeof policies[kind as keyof typeof policies] === 'boolean') {
486+
// constant policy
487+
const func = generateConstantQueryGuardFunction(
488+
sourceFile,
489+
model,
490+
kind,
491+
policies[kind as keyof typeof policies] as boolean
492+
);
493+
writer.write(`guard: ${func.getName()!},`);
494+
return;
495+
}
487496
}
488497

489498
// generate a policy function that evaluates a partial prisma query

packages/schema/src/plugins/enhancer/policy/utils.ts

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/* eslint-disable @typescript-eslint/no-explicit-any */
2-
import type { PolicyKind, PolicyOperationKind } from '@zenstackhq/runtime';
2+
import { DELEGATE_AUX_RELATION_PREFIX, type PolicyKind, type PolicyOperationKind } from '@zenstackhq/runtime';
33
import {
44
ExpressionContext,
55
PluginError,
@@ -15,6 +15,7 @@ import {
1515
getQueryGuardFunctionName,
1616
isAuthInvocation,
1717
isDataModelFieldReference,
18+
isDelegateModel,
1819
isEnumFieldReference,
1920
isFromStdlib,
2021
isFutureExpr,
@@ -39,9 +40,16 @@ import {
3940
} from '@zenstackhq/sdk/ast';
4041
import deepmerge from 'deepmerge';
4142
import { getContainerOfType, streamAllContents, streamAst, streamContents } from 'langium';
43+
import { lowerCaseFirst } from 'lower-case-first';
4244
import { SourceFile, WriterFunction } from 'ts-morph';
4345
import { name } from '..';
44-
import { isCheckInvocation, isCollectionPredicate, isFutureInvocation } from '../../../utils/ast-utils';
46+
import {
47+
getConcreteModels,
48+
getDiscriminatorField,
49+
isCheckInvocation,
50+
isCollectionPredicate,
51+
isFutureInvocation,
52+
} from '../../../utils/ast-utils';
4553
import { ExpressionWriter, FALSE, TRUE } from './expression-writer';
4654

4755
/**
@@ -303,8 +311,11 @@ export function generateQueryGuardFunction(
303311
forField?: DataModelField,
304312
fieldOverride = false
305313
) {
306-
const statements: (string | WriterFunction)[] = [];
314+
if (isDelegateModel(model) && !forField) {
315+
return generateDelegateQueryGuardFunction(sourceFile, model, kind);
316+
}
307317

318+
const statements: (string | WriterFunction)[] = [];
308319
const allowRules = allows.filter((rule) => !hasCrossModelComparison(rule));
309320
const denyRules = denies.filter((rule) => !hasCrossModelComparison(rule));
310321

@@ -438,6 +449,61 @@ export function generateQueryGuardFunction(
438449
return func;
439450
}
440451

452+
function generateDelegateQueryGuardFunction(sourceFile: SourceFile, model: DataModel, kind: PolicyOperationKind) {
453+
const concreteModels = getConcreteModels(model);
454+
455+
const discriminator = getDiscriminatorField(model);
456+
if (!discriminator) {
457+
throw new PluginError(name, `Model '${model.name}' does not have a discriminator field`);
458+
}
459+
460+
const func = sourceFile.addFunction({
461+
name: getQueryGuardFunctionName(model, undefined, false, kind),
462+
returnType: 'any',
463+
parameters: [
464+
{
465+
name: 'context',
466+
type: 'QueryContext',
467+
},
468+
{
469+
// for generating field references used by field comparison in the same model
470+
name: 'db',
471+
type: 'CrudContract',
472+
},
473+
],
474+
statements: (writer) => {
475+
writer.write('return ');
476+
if (concreteModels.length === 0) {
477+
writer.write(TRUE);
478+
} else {
479+
writer.block(() => {
480+
// union all concrete model's guards
481+
writer.writeLine('OR: [');
482+
concreteModels.forEach((concrete) => {
483+
writer.block(() => {
484+
writer.write('AND: [');
485+
// discriminator condition
486+
writer.write(`{ ${discriminator.name}: '${concrete.name}' },`);
487+
// concrete model guard
488+
writer.write(
489+
`{ ${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(
490+
concrete.name
491+
)}: ${getQueryGuardFunctionName(concrete, undefined, false, kind)}(context, db) }`
492+
);
493+
writer.writeLine(']');
494+
});
495+
writer.write(',');
496+
});
497+
writer.writeLine(']');
498+
});
499+
}
500+
writer.write(';');
501+
},
502+
});
503+
504+
return func;
505+
}
506+
441507
export function generateEntityCheckerFunction(
442508
sourceFile: SourceFile,
443509
model: DataModel,

packages/schema/src/plugins/prisma/schema-generator.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ import path from 'path';
5757
import semver from 'semver';
5858
import { name } from '.';
5959
import { getStringLiteral } from '../../language-server/validator/utils';
60+
import { getConcreteModels } from '../../utils/ast-utils';
6061
import { execPackage } from '../../utils/exec-utils';
6162
import { isDefaultWithAuth } from '../enhancer/enhancer-utils';
6263
import {
@@ -320,9 +321,7 @@ export class PrismaSchemaGenerator {
320321
}
321322

322323
// collect concrete models inheriting this model
323-
const concreteModels = decl.$container.declarations.filter(
324-
(d) => isDataModel(d) && d !== decl && d.superTypes.some((base) => base.ref === decl)
325-
);
324+
const concreteModels = getConcreteModels(decl);
326325

327326
// generate an optional relation field in delegate base model to each concrete model
328327
concreteModels.forEach((concrete) => {

packages/schema/src/utils/ast-utils.ts

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,21 @@ import {
22
BinaryExpr,
33
DataModel,
44
DataModelAttribute,
5+
DataModelField,
56
Expression,
67
InheritableNode,
78
isBinaryExpr,
89
isDataModel,
910
isDataModelField,
1011
isInvocationExpr,
1112
isModel,
13+
isReferenceExpr,
1214
isTypeDef,
1315
Model,
1416
ModelImport,
1517
TypeDef,
1618
} from '@zenstackhq/language/ast';
17-
import { getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
19+
import { getAttribute, getInheritanceChain, getRecursiveBases, isDelegateModel, isFromStdlib } from '@zenstackhq/sdk';
1820
import {
1921
AstNode,
2022
copyAstNode,
@@ -310,3 +312,27 @@ export function findUpInheritance(start: DataModel, target: DataModel): DataMode
310312
}
311313
return undefined;
312314
}
315+
316+
/**
317+
* Gets all concrete models that inherit from the given delegate model
318+
*/
319+
export function getConcreteModels(dataModel: DataModel): DataModel[] {
320+
if (!isDelegateModel(dataModel)) {
321+
return [];
322+
}
323+
return dataModel.$container.declarations.filter(
324+
(d): d is DataModel => isDataModel(d) && d !== dataModel && d.superTypes.some((base) => base.ref === dataModel)
325+
);
326+
}
327+
328+
/**
329+
* Gets the discriminator field for the given delegate model
330+
*/
331+
export function getDiscriminatorField(delegate: DataModel) {
332+
const delegateAttr = getAttribute(delegate, '@@delegate');
333+
if (!delegateAttr) {
334+
return undefined;
335+
}
336+
const arg = delegateAttr.args[0]?.value;
337+
return isReferenceExpr(arg) ? (arg.target.ref as DataModelField) : undefined;
338+
}

0 commit comments

Comments
 (0)