Skip to content

Commit 08b9677

Browse files
authored
feat: runtime support for custom @@auth model (#793)
1 parent c390de1 commit 08b9677

File tree

10 files changed

+114
-22
lines changed

10 files changed

+114
-22
lines changed

packages/runtime/src/cross/model-meta.ts

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,25 @@ export type UniqueConstraint = { name: string; fields: string[] };
7878
* ZModel data model metadata
7979
*/
8080
export type ModelMeta = {
81+
/**
82+
* Model fields
83+
*/
8184
fields: Record<string, Record<string, FieldInfo>>;
85+
86+
/**
87+
* Model unique constraints
88+
*/
8289
uniqueConstraints: Record<string, Record<string, UniqueConstraint>>;
90+
91+
/**
92+
* Information for cascading delete
93+
*/
8394
deleteCascade: Record<string, string[]>;
95+
96+
/**
97+
* Name of model that backs the `auth()` function
98+
*/
99+
authModel?: string;
84100
};
85101

86102
/**

packages/runtime/src/enhancements/policy/index.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@ export function withPolicy<DbClient extends object>(
7373
const _zodSchemas = options?.zodSchemas ?? getDefaultZodSchemas(options?.loadPath);
7474

7575
// validate user context
76-
if (context?.user) {
77-
const idFields = getIdFields(_modelMeta, 'User');
76+
if (context?.user && _modelMeta.authModel) {
77+
const idFields = getIdFields(_modelMeta, _modelMeta.authModel);
7878
if (
7979
!hasAllFields(
8080
context.user,

packages/schema/src/cli/cli-util.ts

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import { isDataSource, isPlugin, Model } from '@zenstackhq/language/ast';
2-
import { getLiteral } from '@zenstackhq/sdk';
2+
import { getDataModels, getLiteral, hasAttribute } from '@zenstackhq/sdk';
33
import colors from 'colors';
44
import fs from 'fs';
55
import getLatestVersion from 'get-latest-version';
@@ -95,10 +95,18 @@ export async function loadDocument(fileName: string): Promise<Model> {
9595
function validationAfterMerge(model: Model) {
9696
const dataSources = model.declarations.filter((d) => isDataSource(d));
9797
if (dataSources.length == 0) {
98-
console.error(colors.red('Validation errors: Model must define a datasource'));
98+
console.error(colors.red('Validation error: Model must define a datasource'));
9999
throw new CliError('schema validation errors');
100100
} else if (dataSources.length > 1) {
101-
console.error(colors.red('Validation errors: Multiple datasource declarations are not allowed'));
101+
console.error(colors.red('Validation error: Multiple datasource declarations are not allowed'));
102+
throw new CliError('schema validation errors');
103+
}
104+
105+
// at most one `@@auth` model
106+
const dataModels = getDataModels(model);
107+
const authModels = dataModels.filter((d) => hasAttribute(d, '@@auth'));
108+
if (authModels.length > 1) {
109+
console.error(colors.red('Validation error: Multiple `@@auth` models are not allowed'));
102110
throw new CliError('schema validation errors');
103111
}
104112
}

packages/schema/src/language-server/validator/schema-validator.ts

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import { Model, isDataModel, isDataSource } from '@zenstackhq/language/ast';
2+
import { hasAttribute } from '@zenstackhq/sdk';
3+
import { LangiumDocuments, ValidationAcceptor } from 'langium';
4+
import { getAllDeclarationsFromImports, resolveImport, resolveTransitiveImports } from '../../utils/ast-utils';
15
import { PLUGIN_MODULE_NAME, STD_LIB_MODULE_NAME } from '../constants';
2-
import { isDataSource, Model } from '@zenstackhq/language/ast';
36
import { AstValidator } from '../types';
4-
import { LangiumDocuments, ValidationAcceptor } from 'langium';
57
import { validateDuplicatedDeclarations } from './utils';
6-
import { getAllDeclarationsFromImports, resolveImport, resolveTransitiveImports } from '../../utils/ast-utils';
78

89
/**
910
* Validates toplevel schema.
@@ -33,6 +34,12 @@ export default class SchemaValidator implements AstValidator<Model> {
3334
) {
3435
this.validateDataSources(model, accept);
3536
}
37+
38+
// at most one `@@auth` model
39+
const authModels = model.declarations.filter((d) => isDataModel(d) && hasAttribute(d, '@@auth'));
40+
if (authModels.length > 1) {
41+
accept('error', 'Multiple `@@auth` models are not allowed', { node: authModels[1] });
42+
}
3643
}
3744

3845
private validateDataSources(model: Model, accept: ValidationAcceptor) {

packages/schema/src/language-server/zmodel-linker.ts

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -278,17 +278,16 @@ export class ZModelLinker extends DefaultLinker {
278278
const model = getContainingModel(node);
279279

280280
if (model) {
281-
let userModel;
282-
userModel = getAllDeclarationsFromImports(this.langiumDocuments(), model).find((d) => {
281+
let authModel = getAllDeclarationsFromImports(this.langiumDocuments(), model).find((d) => {
283282
return isDataModel(d) && hasAttribute(d, '@@auth');
284283
});
285-
if (!userModel) {
286-
userModel = getAllDeclarationsFromImports(this.langiumDocuments(), model).find((d) => {
284+
if (!authModel) {
285+
authModel = getAllDeclarationsFromImports(this.langiumDocuments(), model).find((d) => {
287286
return isDataModel(d) && d.name === 'User';
288287
});
289288
}
290-
if (userModel) {
291-
node.$resolvedType = { decl: userModel, nullable: true };
289+
if (authModel) {
290+
node.$resolvedType = { decl: authModel, nullable: true };
292291
}
293292
}
294293
} else if (funcDecl.name === 'future' && isFromStdlib(funcDecl)) {

packages/schema/src/plugins/access-policy/policy-guard-generator.ts

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import {
3535
analyzePolicies,
3636
createProject,
3737
emitProject,
38+
getAuthModel,
3839
getDataModels,
3940
getLiteral,
4041
getPrismaClientImportSpec,
@@ -744,13 +745,11 @@ export default class PolicyGenerator {
744745
);
745746

746747
if (hasAuthRef) {
747-
const userModel = model.$container.declarations.find(
748-
(decl): decl is DataModel => isDataModel(decl) && decl.name === 'User'
749-
);
750-
if (!userModel) {
751-
throw new PluginError(name, 'User model not found');
748+
const authModel = getAuthModel(getDataModels(model.$container));
749+
if (!authModel) {
750+
throw new PluginError(name, 'Auth model not found');
752751
}
753-
const userIdFields = getIdFields(userModel);
752+
const userIdFields = getIdFields(authModel);
754753
if (!userIdFields || userIdFields.length === 0) {
755754
throw new PluginError(name, 'User model does not have an id field');
756755
}

packages/schema/tests/schema/validation/schema-validation.test.ts

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,25 @@ describe('Toplevel Schema Validation Tests', () => {
3838
`)
3939
).toContain('Cannot find model file models/abc.zmodel');
4040
});
41+
42+
it('multiple auth models', async () => {
43+
expect(
44+
await loadModelWithError(`
45+
datasource db1 {
46+
provider = 'postgresql'
47+
url = env('DATABASE_URL')
48+
}
49+
50+
model X {
51+
id String @id
52+
@@auth
53+
}
54+
55+
model Y {
56+
id String @id
57+
@@auth
58+
}
59+
`)
60+
).toContain('Multiple `@@auth` models are not allowed');
61+
});
4162
});

packages/sdk/src/model-meta-generator.ts

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,18 @@ import { lowerCaseFirst } from 'lower-case-first';
1515
import { CodeBlockWriter, Project, VariableDeclarationKind } from 'ts-morph';
1616
import {
1717
emitProject,
18+
getAttribute,
1819
getAttributeArg,
1920
getAttributeArgs,
21+
getAuthModel,
2022
getDataModels,
2123
getLiteral,
2224
hasAttribute,
25+
isEnumFieldReference,
2326
isForeignKeyField,
2427
isIdField,
2528
resolved,
2629
saveProject,
27-
getAttribute,
28-
isEnumFieldReference,
2930
} from '.';
3031

3132
export async function generate(
@@ -113,6 +114,12 @@ function generateModelMetadata(dataModels: DataModel[], writer: CodeBlockWriter)
113114
}
114115
}
115116
});
117+
writer.write(',');
118+
119+
const authModel = getAuthModel(dataModels);
120+
if (authModel) {
121+
writer.writeLine(`authModel: '${authModel.name}'`);
122+
}
116123
});
117124
}
118125

packages/sdk/src/utils.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,3 +352,11 @@ export function getPreviewFeatures(model: Model) {
352352

353353
return [] as string[];
354354
}
355+
356+
export function getAuthModel(dataModels: DataModel[]) {
357+
let authModel = dataModels.find((m) => hasAttribute(m, '@@auth'));
358+
if (!authModel) {
359+
authModel = dataModels.find((m) => m.name === 'User');
360+
}
361+
return authModel;
362+
}

tests/integration/tests/enhancements/with-policy/auth.test.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,4 +185,31 @@ describe('With Policy: auth() test', () => {
185185
const authDb1 = withPolicy({ id: 'user2', role: 'ADMIN' });
186186
await expect(authDb1.post.update({ where: { id: '1' }, data: { title: 'bcd' } })).toResolveTruthy();
187187
});
188+
189+
it('non User auth model', async () => {
190+
const { withPolicy } = await loadSchema(
191+
`
192+
model Foo {
193+
id String @id @default(uuid())
194+
role String
195+
196+
@@auth()
197+
}
198+
199+
model Post {
200+
id String @id @default(uuid())
201+
title String
202+
203+
@@allow('read', true)
204+
@@allow('create', auth().role == 'ADMIN')
205+
}
206+
`
207+
);
208+
209+
const userDb = withPolicy({ id: 'user1', role: 'USER' });
210+
await expect(userDb.post.create({ data: { title: 'abc' } })).toBeRejectedByPolicy();
211+
212+
const adminDb = withPolicy({ id: 'user1', role: 'ADMIN' });
213+
await expect(adminDb.post.create({ data: { title: 'abc' } })).toResolveTruthy();
214+
});
188215
});

0 commit comments

Comments
 (0)