Skip to content

Commit 5b3bc1d

Browse files
committed
fix: several issues with using auth() in @default
- Make generated TS field optional if it has a default - Handle the difference between save and unsafe Prisma mutation
1 parent 2e81a08 commit 5b3bc1d

File tree

12 files changed

+271
-69
lines changed

12 files changed

+271
-69
lines changed

packages/runtime/src/cross/model-meta.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,14 @@ export type FieldInfo = {
7575
isForeignKey?: boolean;
7676

7777
/**
78-
* Mapping from foreign key field names to relation field names
78+
* If the field is a foreign key field, the field name of the corresponding relation field.
79+
* Only available on foreign key fields.
80+
*/
81+
relationField?: string;
82+
83+
/**
84+
* Mapping from foreign key field names to relation field names.
85+
* Only available on relation fields.
7986
*/
8087
foreignKeyMapping?: Record<string, string>;
8188

packages/runtime/src/enhancements/default-auth.ts

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
/* eslint-disable @typescript-eslint/no-explicit-any */
33

44
import deepcopy from 'deepcopy';
5-
import { FieldInfo, NestedWriteVisitor, PrismaWriteActionType, enumerate, getFields } from '../cross';
5+
import { FieldInfo, NestedWriteVisitor, PrismaWriteActionType, enumerate, getFields, requireField } from '../cross';
66
import { DbClientContract } from '../types';
77
import { EnhancementContext, InternalEnhancementOptions } from './create-enhancement';
88
import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy';
9+
import { isUnsafeMutate } from './utils';
910

1011
/**
1112
* Gets an enhanced Prisma client that supports `@default(auth())` attribute.
@@ -68,7 +69,7 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler {
6869
const authDefaultValue = this.getDefaultValueFromAuth(fieldInfo);
6970
if (authDefaultValue !== undefined) {
7071
// set field value extracted from `auth()`
71-
data[fieldInfo.name] = authDefaultValue;
72+
this.setAuthDefaultValue(fieldInfo, model, data, authDefaultValue);
7273
}
7374
}
7475
};
@@ -90,6 +91,47 @@ class DefaultAuthHandler extends DefaultPrismaProxyHandler {
9091
return newArgs;
9192
}
9293

94+
private setAuthDefaultValue(fieldInfo: FieldInfo, model: string, data: any, authDefaultValue: unknown) {
95+
if (fieldInfo.isForeignKey && !isUnsafeMutate(model, data, this.options.modelMeta)) {
96+
// if the field is a fk, and the create payload is not unsafe, we need to translate
97+
// the fk field setting to a `connect` of the corresponding relation field
98+
const relFieldName = fieldInfo.relationField;
99+
if (!relFieldName) {
100+
throw new Error(
101+
`Field \`${fieldInfo.name}\` is a foreign key field but no corresponding relation field is found`
102+
);
103+
}
104+
const relationField = requireField(this.options.modelMeta, model, relFieldName);
105+
106+
// construct a `{ connect: { ... } }` payload
107+
let connect = data[relationField.name]?.connect;
108+
if (!connect) {
109+
connect = {};
110+
data[relationField.name] = { connect };
111+
}
112+
113+
// sets the opposite fk field to value `authDefaultValue`
114+
const oppositeFkFieldName = this.getOppositeFkFieldName(relationField, fieldInfo);
115+
if (!oppositeFkFieldName) {
116+
throw new Error(
117+
`Cannot find opposite foreign key field for \`${fieldInfo.name}\` in relation field \`${relFieldName}\``
118+
);
119+
}
120+
connect[oppositeFkFieldName] = authDefaultValue;
121+
} else {
122+
// set default value directly
123+
data[fieldInfo.name] = authDefaultValue;
124+
}
125+
}
126+
127+
private getOppositeFkFieldName(relationField: FieldInfo, fieldInfo: FieldInfo) {
128+
if (!relationField.foreignKeyMapping) {
129+
return undefined;
130+
}
131+
const entry = Object.entries(relationField.foreignKeyMapping).find(([, v]) => v === fieldInfo.name);
132+
return entry?.[0];
133+
}
134+
93135
private getDefaultValueFromAuth(fieldInfo: FieldInfo) {
94136
if (!this.userContext) {
95137
throw new Error(`Evaluating default value of field \`${fieldInfo.name}\` requires a user context`);

packages/runtime/src/enhancements/policy/handler.ts

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import type { EnhancementContext, InternalEnhancementOptions } from '../create-e
2121
import { Logger } from '../logger';
2222
import { PrismaProxyHandler } from '../proxy';
2323
import { QueryUtils } from '../query-utils';
24-
import { formatObject, prismaClientValidationError } from '../utils';
24+
import { formatObject, isUnsafeMutate, prismaClientValidationError } from '../utils';
2525
import { PolicyUtil } from './policy-utils';
2626
import { createDeferredPromise } from './promise';
2727

@@ -691,7 +691,7 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
691691
// operations. E.g.:
692692
// - safe: { data: { user: { connect: { id: 1 }} } }
693693
// - unsafe: { data: { userId: 1 } }
694-
const unsafe = this.isUnsafeMutate(model, args);
694+
const unsafe = isUnsafeMutate(model, args, this.modelMeta);
695695

696696
// handles the connection to upstream entity
697697
const reversedQuery = this.policyUtils.buildReversedQuery(context, true, unsafe);
@@ -1083,23 +1083,6 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
10831083
}
10841084
}
10851085

1086-
private isUnsafeMutate(model: string, args: any) {
1087-
if (!args) {
1088-
return false;
1089-
}
1090-
for (const k of Object.keys(args)) {
1091-
const field = resolveField(this.modelMeta, model, k);
1092-
if (field && (this.isAutoIncrementIdField(field) || field.isForeignKey)) {
1093-
return true;
1094-
}
1095-
}
1096-
return false;
1097-
}
1098-
1099-
private isAutoIncrementIdField(field: FieldInfo) {
1100-
return field.isId && field.isAutoIncrement;
1101-
}
1102-
11031086
async updateMany(args: any) {
11041087
if (!args) {
11051088
throw prismaClientValidationError(this.prisma, this.prismaModule, 'query argument is required');

packages/runtime/src/enhancements/utils.ts

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import * as util from 'util';
2+
import { FieldInfo, ModelMeta, resolveField } from '..';
23
import type { DbClientContract } from '../types';
34

45
/**
@@ -22,3 +23,21 @@ export function prismaClientKnownRequestError(prisma: DbClientContract, prismaMo
2223
export function prismaClientUnknownRequestError(prismaModule: any, ...args: unknown[]): Error {
2324
throw new prismaModule.PrismaClientUnknownRequestError(...args);
2425
}
26+
27+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
28+
export function isUnsafeMutate(model: string, args: any, modelMeta: ModelMeta) {
29+
if (!args) {
30+
return false;
31+
}
32+
for (const k of Object.keys(args)) {
33+
const field = resolveField(modelMeta, model, k);
34+
if (field && (isAutoIncrementIdField(field) || field.isForeignKey)) {
35+
return true;
36+
}
37+
}
38+
return false;
39+
}
40+
41+
export function isAutoIncrementIdField(field: FieldInfo) {
42+
return field.isId && field.isAutoIncrement;
43+
}

packages/schema/src/plugins/enhancer/enhance/index.ts

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import { name } from '..';
2727
import { execPackage } from '../../../utils/exec-utils';
2828
import { trackPrismaSchemaError } from '../../prisma';
2929
import { PrismaSchemaGenerator } from '../../prisma/schema-generator';
30+
import { isDefaultWithAuth } from '../enhancer-utils';
3031

3132
// information of delegate models and their sub models
3233
type DelegateInfo = [DataModel, DataModel[]][];
@@ -35,7 +36,7 @@ export async function generate(model: Model, options: PluginOptions, project: Pr
3536
let logicalPrismaClientDir: string | undefined;
3637
let dmmf: DMMF.Document | undefined;
3738

38-
if (hasDelegateModel(model)) {
39+
if (needsLogicalClient(model)) {
3940
// schema contains delegate models, need to generate a logical prisma schema
4041
const result = await generateLogicalPrisma(model, options, outDir);
4142

@@ -86,13 +87,23 @@ export function enhance<DbClient extends object>(prisma: DbClient, context?: Enh
8687
return { dmmf };
8788
}
8889

90+
function needsLogicalClient(model: Model) {
91+
return hasDelegateModel(model) || hasAuthInDefault(model);
92+
}
93+
8994
function hasDelegateModel(model: Model) {
9095
const dataModels = getDataModels(model);
9196
return dataModels.some(
9297
(dm) => isDelegateModel(dm) && dataModels.some((sub) => sub.superTypes.some((base) => base.ref === dm))
9398
);
9499
}
95100

101+
function hasAuthInDefault(model: Model) {
102+
return getDataModels(model).some((dm) =>
103+
dm.fields.some((f) => f.attributes.some((attr) => isDefaultWithAuth(attr)))
104+
);
105+
}
106+
96107
async function generateLogicalPrisma(model: Model, options: PluginOptions, outDir: string) {
97108
const prismaGenerator = new PrismaSchemaGenerator(model);
98109
const prismaClientOutDir = './.logical-prisma-client';
@@ -152,12 +163,19 @@ async function processClientTypes(model: Model, prismaClientDir: string) {
152163
const sfNew = project.createSourceFile(path.join(prismaClientDir, 'index-fixed.d.ts'), undefined, {
153164
overwrite: true,
154165
});
155-
transform(sf, sfNew, delegateInfo);
156-
sfNew.formatText();
166+
167+
if (delegateInfo.length > 0) {
168+
// transform types for delegated models
169+
transformDelegate(sf, sfNew, delegateInfo);
170+
sfNew.formatText();
171+
} else {
172+
// just copy
173+
sfNew.replaceWithText(sf.getFullText());
174+
}
157175
await sfNew.save();
158176
}
159177

160-
function transform(sf: SourceFile, sfNew: SourceFile, delegateModels: DelegateInfo) {
178+
function transformDelegate(sf: SourceFile, sfNew: SourceFile, delegateModels: DelegateInfo) {
161179
// copy toplevel imports
162180
sfNew.addImportDeclarations(sf.getImportDeclarations().map((n) => n.getStructure()));
163181

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import { isAuthInvocation } from '@zenstackhq/sdk';
2+
import type { DataModelFieldAttribute } from '@zenstackhq/sdk/ast';
3+
import { streamAst } from 'langium';
4+
5+
/**
6+
* Check if the given field attribute is a `@default` with `auth()` invocation
7+
*/
8+
export function isDefaultWithAuth(attr: DataModelFieldAttribute) {
9+
if (attr.decl.ref?.name !== '@default') {
10+
return false;
11+
}
12+
13+
const expr = attr.args[0]?.value;
14+
if (!expr) {
15+
return false;
16+
}
17+
18+
// find `auth()` in default value expression
19+
return streamAst(expr).some(isAuthInvocation);
20+
}

packages/schema/src/plugins/prisma/schema-generator.ts

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,26 +34,27 @@ import { getIdFields } from '../../utils/ast-utils';
3434
import { DELEGATE_AUX_RELATION_PREFIX, PRISMA_MINIMUM_VERSION } from '@zenstackhq/runtime';
3535
import {
3636
getAttribute,
37+
getForeignKeyFields,
3738
getLiteral,
3839
getPrismaVersion,
39-
isAuthInvocation,
4040
isDelegateModel,
4141
isIdField,
42+
isRelationshipField,
4243
PluginError,
4344
PluginOptions,
4445
resolved,
4546
ZModelCodeGenerator,
4647
} from '@zenstackhq/sdk';
4748
import fs from 'fs';
4849
import { writeFile } from 'fs/promises';
49-
import { streamAst } from 'langium';
5050
import { lowerCaseFirst } from 'lower-case-first';
5151
import path from 'path';
5252
import semver from 'semver';
5353
import { upperCaseFirst } from 'upper-case-first';
5454
import { name } from '.';
5555
import { getStringLiteral } from '../../language-server/validator/utils';
5656
import { execPackage } from '../../utils/exec-utils';
57+
import { isDefaultWithAuth } from '../enhancer/enhancer-utils';
5758
import {
5859
AttributeArgValue,
5960
ModelFieldType,
@@ -494,10 +495,27 @@ export class PrismaSchemaGenerator {
494495

495496
const type = new ModelFieldType(fieldType, field.type.array, field.type.optional);
496497

498+
if (this.mode === 'logical') {
499+
if (field.attributes.some((attr) => isDefaultWithAuth(attr))) {
500+
// field has `@default` with `auth()`, it should be set optional, and the
501+
// default value setting is handled outside Prisma
502+
type.optional = true;
503+
}
504+
505+
if (isRelationshipField(field)) {
506+
// if foreign key field has `@default` with `auth()`, the relation
507+
// field should be set optional
508+
const foreignKeyFields = getForeignKeyFields(field);
509+
if (foreignKeyFields.some((fkField) => fkField.attributes.some((attr) => isDefaultWithAuth(attr)))) {
510+
type.optional = true;
511+
}
512+
}
513+
}
514+
497515
const attributes = field.attributes
498516
.filter((attr) => this.isPrismaAttribute(attr))
499517
// `@default` with `auth()` is handled outside Prisma
500-
.filter((attr) => !this.isDefaultWithAuth(attr))
518+
.filter((attr) => !isDefaultWithAuth(attr))
501519
.filter(
502520
(attr) =>
503521
// when building physical schema, exclude `@default` for id fields inherited from delegate base
@@ -524,20 +542,6 @@ export class PrismaSchemaGenerator {
524542
return field.$inheritedFrom && isDelegateModel(field.$inheritedFrom);
525543
}
526544

527-
private isDefaultWithAuth(attr: DataModelFieldAttribute) {
528-
if (attr.decl.ref?.name !== '@default') {
529-
return false;
530-
}
531-
532-
const expr = attr.args[0]?.value;
533-
if (!expr) {
534-
return false;
535-
}
536-
537-
// find `auth()` in default value expression
538-
return streamAst(expr).some(isAuthInvocation);
539-
}
540-
541545
private makeFieldAttribute(attr: DataModelFieldAttribute) {
542546
const attrName = resolved(attr.decl).name;
543547
if (attrName === FIELD_PASSTHROUGH_ATTR) {

packages/schema/tests/schema/validation/attribute-validation.test.ts

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ describe('Attribute tests', () => {
227227
`);
228228

229229
await loadModel(`
230-
${ prelude }
230+
${prelude}
231231
model A {
232232
id String @id
233233
x String
@@ -1051,21 +1051,6 @@ describe('Attribute tests', () => {
10511051
}
10521052
`);
10531053

1054-
// expect(
1055-
// await loadModelWithError(`
1056-
// ${prelude}
1057-
1058-
// model User {
1059-
// id String @id
1060-
// name String
1061-
// }
1062-
// model B {
1063-
// id String @id
1064-
// userData String @default(auth())
1065-
// }
1066-
// `)
1067-
// ).toContain("Value is not assignable to parameter");
1068-
10691054
expect(
10701055
await loadModelWithError(`
10711056
${prelude}
@@ -1185,15 +1170,6 @@ describe('Attribute tests', () => {
11851170
});
11861171

11871172
it('incorrect function expression context', async () => {
1188-
// expect(
1189-
// await loadModelWithError(`
1190-
// ${prelude}
1191-
// model M {
1192-
// id String @id @default(auth())
1193-
// }
1194-
// `)
1195-
// ).toContain('function "auth" is not allowed in the current context: DefaultValue');
1196-
11971173
expect(
11981174
await loadModelWithError(`
11991175
${prelude}

packages/sdk/src/model-meta-generator.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ import {
3232
isIdField,
3333
resolved,
3434
TypeScriptExpressionTransformer,
35+
getRelationField,
3536
} from '.';
3637

3738
/**
@@ -247,6 +248,11 @@ function writeFields(
247248
if (isForeignKeyField(f)) {
248249
writer.write(`
249250
isForeignKey: true,`);
251+
const relationField = getRelationField(f);
252+
if (relationField) {
253+
writer.write(`
254+
relationField: '${relationField.name}',`);
255+
}
250256
}
251257

252258
if (fkMapping && Object.keys(fkMapping).length > 0) {
@@ -408,7 +414,6 @@ function generateForeignKeyMapping(field: DataModelField) {
408414
const fieldNames = fields.items.map((item) => (isReferenceExpr(item) ? item.target.$refText : undefined));
409415
const referenceNames = references.items.map((item) => (isReferenceExpr(item) ? item.target.$refText : undefined));
410416

411-
// eslint-disable-next-line @typescript-eslint/no-explicit-any
412417
const result: Record<string, string> = {};
413418
referenceNames.forEach((name, i) => {
414419
if (name) {

0 commit comments

Comments
 (0)