Skip to content

Commit d6618c9

Browse files
committed
more fixes
1 parent ff86ce0 commit d6618c9

File tree

4 files changed

+46
-125
lines changed

4 files changed

+46
-125
lines changed

packages/schema/src/plugins/enhancer/policy/policy-guard-generator.ts

Lines changed: 29 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import {
2525
hasAttribute,
2626
hasValidationAttributes,
2727
isAuthInvocation,
28-
isDelegateModel,
2928
isForeignKeyField,
3029
saveSourceFile,
3130
} from '@zenstackhq/sdk';
@@ -455,44 +454,38 @@ export class PolicyGenerator {
455454
writer: CodeBlockWriter,
456455
sourceFile: SourceFile
457456
) {
458-
const isDelegate = isDelegateModel(model);
459-
460-
if (!isDelegate) {
461-
// handle cases where a constant function can be used
462-
// note that this doesn't apply to delegate models because
463-
// all concrete models inheriting it need to be considered
464-
465-
if (kind === 'update' && allows.length === 0) {
466-
// no allow rule for 'update', policy is constant based on if there's
467-
// post-update counterpart
468-
let func: FunctionDeclaration;
469-
if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) {
470-
func = generateConstantQueryGuardFunction(sourceFile, model, kind, false);
471-
} else {
472-
func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
473-
}
474-
writer.write(`guard: ${func.getName()!},`);
475-
return;
476-
}
457+
// first handle several cases where a constant function can be used
477458

478-
if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) {
479-
// no 'postUpdate' rule, always allow
480-
const func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
481-
writer.write(`guard: ${func.getName()},`);
482-
return;
459+
if (kind === 'update' && allows.length === 0) {
460+
// no allow rule for 'update', policy is constant based on if there's
461+
// post-update counterpart
462+
let func: FunctionDeclaration;
463+
if (getPolicyExpressions(model, 'allow', 'postUpdate').length === 0) {
464+
func = generateConstantQueryGuardFunction(sourceFile, model, kind, false);
465+
} else {
466+
func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
483467
}
468+
writer.write(`guard: ${func.getName()!},`);
469+
return;
470+
}
484471

485-
if (kind in policies && typeof policies[kind as keyof typeof policies] === 'boolean') {
486-
// constant policy
487-
const func = generateConstantQueryGuardFunction(
488-
sourceFile,
489-
model,
490-
kind,
491-
policies[kind as keyof typeof policies] as boolean
492-
);
493-
writer.write(`guard: ${func.getName()!},`);
494-
return;
495-
}
472+
if (kind === 'postUpdate' && allows.length === 0 && denies.length === 0) {
473+
// no 'postUpdate' rule, always allow
474+
const func = generateConstantQueryGuardFunction(sourceFile, model, kind, true);
475+
writer.write(`guard: ${func.getName()},`);
476+
return;
477+
}
478+
479+
if (kind in policies && typeof policies[kind as keyof typeof policies] === 'boolean') {
480+
// constant policy
481+
const func = generateConstantQueryGuardFunction(
482+
sourceFile,
483+
model,
484+
kind,
485+
policies[kind as keyof typeof policies] as boolean
486+
);
487+
writer.write(`guard: ${func.getName()!},`);
488+
return;
496489
}
497490

498491
// generate a policy function that evaluates a partial prisma query

packages/schema/src/plugins/enhancer/policy/utils.ts

Lines changed: 2 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/* eslint-disable @typescript-eslint/no-explicit-any */
2-
import { DELEGATE_AUX_RELATION_PREFIX, type PolicyKind, type PolicyOperationKind } from '@zenstackhq/runtime';
2+
import { type PolicyKind, type PolicyOperationKind } from '@zenstackhq/runtime';
33
import {
44
ExpressionContext,
55
PluginError,
@@ -15,7 +15,6 @@ import {
1515
getQueryGuardFunctionName,
1616
isAuthInvocation,
1717
isDataModelFieldReference,
18-
isDelegateModel,
1918
isEnumFieldReference,
2019
isFromStdlib,
2120
isFutureExpr,
@@ -40,16 +39,9 @@ import {
4039
} from '@zenstackhq/sdk/ast';
4140
import deepmerge from 'deepmerge';
4241
import { getContainerOfType, streamAllContents, streamAst, streamContents } from 'langium';
43-
import { lowerCaseFirst } from 'lower-case-first';
4442
import { SourceFile, WriterFunction } from 'ts-morph';
4543
import { name } from '..';
46-
import {
47-
getConcreteModels,
48-
getDiscriminatorField,
49-
isCheckInvocation,
50-
isCollectionPredicate,
51-
isFutureInvocation,
52-
} from '../../../utils/ast-utils';
44+
import { isCheckInvocation, isCollectionPredicate, isFutureInvocation } from '../../../utils/ast-utils';
5345
import { ExpressionWriter, FALSE, TRUE } from './expression-writer';
5446

5547
/**
@@ -311,10 +303,6 @@ export function generateQueryGuardFunction(
311303
forField?: DataModelField,
312304
fieldOverride = false
313305
) {
314-
if (isDelegateModel(model) && !forField) {
315-
return generateDelegateQueryGuardFunction(sourceFile, model, kind);
316-
}
317-
318306
const statements: (string | WriterFunction)[] = [];
319307
const allowRules = allows.filter((rule) => !hasCrossModelComparison(rule));
320308
const denyRules = denies.filter((rule) => !hasCrossModelComparison(rule));
@@ -449,61 +437,6 @@ export function generateQueryGuardFunction(
449437
return func;
450438
}
451439

452-
function generateDelegateQueryGuardFunction(sourceFile: SourceFile, model: DataModel, kind: PolicyOperationKind) {
453-
const concreteModels = getConcreteModels(model);
454-
455-
const discriminator = getDiscriminatorField(model);
456-
if (!discriminator) {
457-
throw new PluginError(name, `Model '${model.name}' does not have a discriminator field`);
458-
}
459-
460-
const func = sourceFile.addFunction({
461-
name: getQueryGuardFunctionName(model, undefined, false, kind),
462-
returnType: 'any',
463-
parameters: [
464-
{
465-
name: 'context',
466-
type: 'QueryContext',
467-
},
468-
{
469-
// for generating field references used by field comparison in the same model
470-
name: 'db',
471-
type: 'CrudContract',
472-
},
473-
],
474-
statements: (writer) => {
475-
writer.write('return ');
476-
if (concreteModels.length === 0) {
477-
writer.write(TRUE);
478-
} else {
479-
writer.block(() => {
480-
// union all concrete model's guards
481-
writer.writeLine('OR: [');
482-
concreteModels.forEach((concrete) => {
483-
writer.block(() => {
484-
writer.write('AND: [');
485-
// discriminator condition
486-
writer.write(`{ ${discriminator.name}: '${concrete.name}' },`);
487-
// concrete model guard
488-
writer.write(
489-
`{ ${DELEGATE_AUX_RELATION_PREFIX}_${lowerCaseFirst(
490-
concrete.name
491-
)}: ${getQueryGuardFunctionName(concrete, undefined, false, kind)}(context, db) }`
492-
);
493-
writer.writeLine(']');
494-
});
495-
writer.write(',');
496-
});
497-
writer.writeLine(']');
498-
});
499-
}
500-
writer.write(';');
501-
},
502-
});
503-
504-
return func;
505-
}
506-
507440
export function generateEntityCheckerFunction(
508441
sourceFile: SourceFile,
509442
model: DataModel,

tests/integration/tests/enhancements/with-delegate/policy-interaction.test.ts

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,16 @@ describe('Polymorphic Policy Test', () => {
8787
`;
8888

8989
for (const schema of [booleanCondition, booleanExpression]) {
90-
const { enhanceRaw: enhance, prisma } = await loadSchema(schema);
90+
const { enhanceRaw: enhance, prisma } = await loadSchema(schema, { logPrismaQuery: true });
9191

9292
const fullDb = enhance(prisma, undefined, { kinds: ['delegate'] });
9393

9494
const user = await fullDb.user.create({ data: { id: 1 } });
95-
const userDb = enhance(prisma, { user: { id: user.id } }, { kinds: ['delegate', 'policy'] });
95+
const userDb = enhance(
96+
prisma,
97+
{ user: { id: user.id } },
98+
{ kinds: ['delegate', 'policy'], logPrismaQuery: true }
99+
);
96100

97101
// violating Asset create
98102
await expect(
@@ -588,13 +592,14 @@ describe('Polymorphic Policy Test', () => {
588592
type String
589593
590594
@@delegate(type)
595+
@@allow('all', true)
591596
}
592597
593598
model Post extends Asset {
594599
title String
595600
private Boolean
596601
@@allow('create', true)
597-
@@allow('read', !private)
602+
@@deny('read', private)
598603
}
599604
`
600605
);
@@ -607,9 +612,9 @@ describe('Polymorphic Policy Test', () => {
607612
});
608613

609614
const db = enhance();
610-
await expect(db.user.findUnique({ where: { id: 1 }, include: { asset: true } })).resolves.toMatchObject({
611-
asset: null,
612-
});
615+
const read = await db.user.findUnique({ where: { id: 1 }, include: { asset: true } });
616+
expect(read.asset).toBeTruthy();
617+
expect(read.asset.title).toBeUndefined();
613618
});
614619

615620
it('respects concrete policies when read as base required relation', async () => {
@@ -636,8 +641,7 @@ describe('Polymorphic Policy Test', () => {
636641
private Boolean
637642
@@deny('read', private)
638643
}
639-
`,
640-
{ logPrismaQuery: true }
644+
`
641645
);
642646

643647
const fullDb = enhance(undefined, { kinds: ['delegate'] });
@@ -647,6 +651,8 @@ describe('Polymorphic Policy Test', () => {
647651
});
648652

649653
const db = enhance();
650-
await expect(db.user.findUnique({ where: { id: 1 }, include: { asset: true } })).toResolveNull();
654+
const read = await db.user.findUnique({ where: { id: 1 }, include: { asset: true } });
655+
expect(read).toBeTruthy();
656+
expect(read.asset.title).toBeUndefined();
651657
});
652658
});

tests/regression/tests/issue-1930.test.ts

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ model EntityContent {
3838
}
3939
4040
model Article extends Entity {
41-
private Boolean @default(false)
42-
@@deny('all', private)
4341
}
4442
4543
model ArticleContent extends EntityContent {
@@ -78,14 +76,5 @@ model OtherContent extends EntityContent {
7876
data: { body: 'bcd', entity: { connect: { id: deletedArticle.id } } },
7977
});
8078
await expect(db.articleContent.findUnique({ where: { id: content1.id } })).toResolveNull();
81-
82-
// private article's contents are not readable
83-
const privateArticle = await fullDb.article.create({
84-
data: { org: { connect: { id: org.id } }, private: true },
85-
});
86-
const content2 = await fullDb.articleContent.create({
87-
data: { body: 'cde', entity: { connect: { id: privateArticle.id } } },
88-
});
89-
await expect(db.articleContent.findUnique({ where: { id: content2.id } })).toResolveNull();
9079
});
9180
});

0 commit comments

Comments
 (0)