Skip to content

Support for auth() in @default attribute #958

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/ide/jetbrains/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"homepage": "https://zenstack.dev",
"private": true,
"scripts": {
"build": "./gradlew buildPlugin"
"build": "./gradlew buildPlugin"
},
"author": "ZenStack Team",
"license": "MIT",
Expand Down
10 changes: 10 additions & 0 deletions packages/runtime/src/cross/model-meta.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ export type RuntimeAttribute = {
args: Array<{ name?: string; value: unknown }>;
};

/**
* Function for computing default value for a field
*/
export type FieldDefaultValueProvider = (userContext: unknown) => unknown;

/**
* Runtime information of a data model field
*/
Expand Down Expand Up @@ -67,6 +72,11 @@ export type FieldInfo = {
*/
foreignKeyMapping?: Record<string, string>;

/**
* A function that provides a default value for the field
*/
defaultValueProvider?: FieldDefaultValueProvider;

/**
* If the field is an auto-increment field
*/
Expand Down
2 changes: 1 addition & 1 deletion packages/runtime/src/cross/nested-write-visitor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ export type NestedWriteVisitorContext = {
* to let the visitor traverse it instead of its original children.
*/
export type NestedWriterVisitorCallback = {
create?: (model: string, args: any[], context: NestedWriteVisitorContext) => MaybePromise<boolean | object | void>;
create?: (model: string, data: any, context: NestedWriteVisitorContext) => MaybePromise<boolean | object | void>;

createMany?: (
model: string,
Expand Down
23 changes: 21 additions & 2 deletions packages/runtime/src/enhancements/create-enhancement.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { withPassword } from './password';
import { withPolicy } from './policy';
import type { ErrorTransformer } from './proxy';
import type { PolicyDef, ZodSchemas } from './types';
import { withDefaultAuth } from './default-auth';

/**
* Kinds of enhancements to `PrismaClient`
Expand All @@ -15,6 +16,7 @@ export enum EnhancementKind {
Password = 'password',
Omit = 'omit',
Policy = 'policy',
DefaultAuth = 'defaultAuth',
}

/**
Expand Down Expand Up @@ -92,6 +94,7 @@ export type EnhancementContext = {

let hasPassword: boolean | undefined = undefined;
let hasOmit: boolean | undefined = undefined;
let hasDefaultAuth: boolean | undefined = undefined;

/**
* Gets a Prisma client enhanced with all enhancement behaviors, including access
Expand Down Expand Up @@ -120,13 +123,24 @@ export function createEnhancement<DbClient extends object>(

let result = prisma;

if (hasPassword === undefined || hasOmit === undefined) {
if (
process.env.ZENSTACK_TEST === '1' || // avoid caching in tests
hasPassword === undefined ||
hasOmit === undefined ||
hasDefaultAuth === undefined
) {
const allFields = Object.values(options.modelMeta.fields).flatMap((modelInfo) => Object.values(modelInfo));
hasPassword = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@password'));
hasOmit = allFields.some((field) => field.attributes?.some((attr) => attr.name === '@omit'));
hasDefaultAuth = allFields.some((field) => field.defaultValueProvider);
}

const kinds = options.kinds ?? [EnhancementKind.Password, EnhancementKind.Omit, EnhancementKind.Policy];
const kinds = options.kinds ?? [
EnhancementKind.Password,
EnhancementKind.Omit,
EnhancementKind.Policy,
EnhancementKind.DefaultAuth,
];

if (hasPassword && kinds.includes(EnhancementKind.Password)) {
// @password proxy
Expand All @@ -138,6 +152,11 @@ export function createEnhancement<DbClient extends object>(
result = withOmit(result, options);
}

if (hasDefaultAuth && kinds.includes(EnhancementKind.DefaultAuth)) {
// @default(auth()) proxy
result = withDefaultAuth(result, options, context);
}

// policy proxy
if (kinds.includes(EnhancementKind.Policy)) {
result = withPolicy(result, options, context);
Expand Down
102 changes: 102 additions & 0 deletions packages/runtime/src/enhancements/default-auth.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/* eslint-disable @typescript-eslint/no-unused-vars */
/* eslint-disable @typescript-eslint/no-explicit-any */

import deepcopy from 'deepcopy';
import { FieldInfo, NestedWriteVisitor, PrismaWriteActionType, enumerate, getFields } from '../cross';
import { DbClientContract } from '../types';
import { EnhancementContext, EnhancementOptions } from './create-enhancement';
import { DefaultPrismaProxyHandler, PrismaProxyActions, makeProxy } from './proxy';

/**
* Gets an enhanced Prisma client that supports `@default(auth())` attribute.
*
* @private
*/
export function withDefaultAuth<DbClient extends object>(
prisma: DbClient,
options: EnhancementOptions,
context?: EnhancementContext
): DbClient {
return makeProxy(
prisma,
options.modelMeta,
(_prisma, model) => new DefaultAuthHandler(_prisma as DbClientContract, model, options, context),
'defaultAuth'
);
}

class DefaultAuthHandler extends DefaultPrismaProxyHandler {
private readonly db: DbClientContract;
private readonly userContext: any;

constructor(
prisma: DbClientContract,
model: string,
private readonly options: EnhancementOptions,
private readonly context?: EnhancementContext
) {
super(prisma, model);
this.db = prisma;

if (!this.context?.user) {
throw new Error(`Using \`auth()\` in \`@default\` requires a user context`);
}

this.userContext = this.context.user;
}

// base override
protected async preprocessArgs(action: PrismaProxyActions, args: any) {
const actionsOfInterest: PrismaProxyActions[] = ['create', 'createMany', 'update', 'updateMany', 'upsert'];
if (actionsOfInterest.includes(action)) {
const newArgs = await this.preprocessWritePayload(this.model, action as PrismaWriteActionType, args);
return newArgs;
}
return args;
}

private async preprocessWritePayload(model: string, action: PrismaWriteActionType, args: any) {
const newArgs = deepcopy(args);

const processCreatePayload = (model: string, data: any) => {
const fields = getFields(this.options.modelMeta, model);
for (const fieldInfo of Object.values(fields)) {
if (fieldInfo.name in data) {
// create payload already sets field value
continue;
}

if (!fieldInfo.defaultValueProvider) {
// field doesn't have a runtime default value provider
continue;
}

const authDefaultValue = this.getDefaultValueFromAuth(fieldInfo);
if (authDefaultValue !== undefined) {
// set field value extracted from `auth()`
data[fieldInfo.name] = authDefaultValue;
}
}
};

// visit create payload and set default value to fields using `auth()` in `@default()`
const visitor = new NestedWriteVisitor(this.options.modelMeta, {
create: (model, data) => {
processCreatePayload(model, data);
},

createMany: (model, args) => {
for (const item of enumerate(args.data)) {
processCreatePayload(model, item);
}
},
});

await visitor.visit(model, action, newArgs);
return newArgs;
}

private getDefaultValueFromAuth(fieldInfo: FieldInfo) {
return fieldInfo.defaultValueProvider?.(this.userContext);
}
}
15 changes: 15 additions & 0 deletions packages/runtime/src/enhancements/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,18 @@ export function prismaClientKnownRequestError(prisma: DbClientContract, prismaMo
export function prismaClientUnknownRequestError(prismaModule: any, ...args: unknown[]): Error {
throw new prismaModule.PrismaClientUnknownRequestError(...args);
}

export function deepGet(object: object, path: string | string[] | undefined, defaultValue: unknown): unknown {
if (path === undefined || path === '') {
return defaultValue;
}
const keys = Array.isArray(path) ? path : path.split('.');
for (const key of keys) {
if (object && typeof object === 'object' && key in object) {
object = object[key as keyof typeof object];
} else {
return defaultValue;
}
}
return object !== undefined ? object : defaultValue;
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@ import {
Expression,
ExpressionType,
isDataModel,
isDataModelField,
isEnum,
isLiteralExpr,
isMemberAccessExpr,
isNullExpr,
isThisExpr,
isDataModelField,
isLiteralExpr,
} from '@zenstackhq/language/ast';
import { isDataModelFieldReference, isEnumFieldReference } from '@zenstackhq/sdk';
import { isAuthInvocation, isDataModelFieldReference, isEnumFieldReference } from '@zenstackhq/sdk';
import { ValidationAcceptor } from 'langium';
import { getContainingDataModel, isAuthInvocation, isCollectionPredicate } from '../../utils/ast-utils';
import { getContainingDataModel, isCollectionPredicate } from '../../utils/ast-utils';
import { AstValidator } from '../types';
import { typeAssignable } from './utils';

Expand Down Expand Up @@ -132,18 +132,24 @@ export default class ExpressionValidator implements AstValidator<Expression> {
// - foo.user.id == userId
// except:
// - future().userId == userId
if(isMemberAccessExpr(expr.left) && isDataModelField(expr.left.member.ref) && expr.left.member.ref.$container != getContainingDataModel(expr)
|| isMemberAccessExpr(expr.right) && isDataModelField(expr.right.member.ref) && expr.right.member.ref.$container != getContainingDataModel(expr))
{
if (
(isMemberAccessExpr(expr.left) &&
isDataModelField(expr.left.member.ref) &&
expr.left.member.ref.$container != getContainingDataModel(expr)) ||
(isMemberAccessExpr(expr.right) &&
isDataModelField(expr.right.member.ref) &&
expr.right.member.ref.$container != getContainingDataModel(expr))
) {
// foo.user.id == auth().id
// foo.user.id == "123"
// foo.user.id == null
// foo.user.id == EnumValue
if(!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right)))
{
accept('error', 'comparison between fields of different models are not supported', { node: expr });
break;
}
if (!(this.isNotModelFieldExpr(expr.left) || this.isNotModelFieldExpr(expr.right))) {
accept('error', 'comparison between fields of different models are not supported', {
node: expr,
});
break;
}
}

if (
Expand Down Expand Up @@ -205,14 +211,13 @@ export default class ExpressionValidator implements AstValidator<Expression> {
}
}


private isNotModelFieldExpr(expr: Expression) {
return isLiteralExpr(expr) || isEnumFieldReference(expr) || isNullExpr(expr) || this.isAuthOrAuthMemberAccess(expr)
return (
isLiteralExpr(expr) || isEnumFieldReference(expr) || isNullExpr(expr) || this.isAuthOrAuthMemberAccess(expr)
);
}

private isAuthOrAuthMemberAccess(expr: Expression) {
return isAuthInvocation(expr) || (isMemberAccessExpr(expr) && isAuthInvocation(expr.operand));
}

}

Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,15 @@ import {
isDataModelFieldAttribute,
isLiteralExpr,
} from '@zenstackhq/language/ast';
import { ExpressionContext, getFunctionExpressionContext, isEnumFieldReference, isFromStdlib } from '@zenstackhq/sdk';
import {
ExpressionContext,
getDataModelFieldReference,
getFunctionExpressionContext,
isEnumFieldReference,
isFromStdlib,
} from '@zenstackhq/sdk';
import { AstNode, ValidationAcceptor } from 'langium';
import { P, match } from 'ts-pattern';
import { getDataModelFieldReference } from '../../utils/ast-utils';
import { AstValidator } from '../types';
import { typeAssignable } from './utils';

Expand Down
13 changes: 4 additions & 9 deletions packages/schema/src/language-server/zmodel-linker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import {
isReferenceExpr,
isStringLiteral,
} from '@zenstackhq/language/ast';
import { getContainingModel, hasAttribute, isFromStdlib } from '@zenstackhq/sdk';
import { getContainingModel, hasAttribute, isAuthInvocation, isFutureExpr } from '@zenstackhq/sdk';
import {
AstNode,
AstNodeDescription,
Expand All @@ -52,12 +52,7 @@ import {
} from 'langium';
import { match } from 'ts-pattern';
import { CancellationToken } from 'vscode-jsonrpc';
import {
getAllDeclarationsFromImports,
getContainingDataModel,
isAuthInvocation,
isCollectionPredicate,
} from '../utils/ast-utils';
import { getAllDeclarationsFromImports, getContainingDataModel, isCollectionPredicate } from '../utils/ast-utils';
import { mapBuiltinTypeToExpressionType } from './validator/utils';

interface DefaultReference extends Reference {
Expand Down Expand Up @@ -329,7 +324,7 @@ export class ZModelLinker extends DefaultLinker {
if (node.function.ref) {
// eslint-disable-next-line @typescript-eslint/ban-types
const funcDecl = node.function.ref as FunctionDecl;
if (funcDecl.name === 'auth' && isFromStdlib(funcDecl)) {
if (isAuthInvocation(node)) {
// auth() function is resolved to User model in the current document
const model = getContainingModel(node);

Expand All @@ -346,7 +341,7 @@ export class ZModelLinker extends DefaultLinker {
node.$resolvedType = { decl: authModel, nullable: true };
}
}
} else if (funcDecl.name === 'future' && isFromStdlib(funcDecl)) {
} else if (isFutureExpr(node)) {
// future() function is resolved to current model
node.$resolvedType = { decl: getContainingDataModel(node) };
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,18 @@ import {
import {
ExpressionContext,
getFunctionExpressionContext,
getIdFields,
getLiteral,
isAuthInvocation,
isDataModelFieldReference,
isFutureExpr,
PluginError,
TypeScriptExpressionTransformer,
TypeScriptExpressionTransformerError,
} from '@zenstackhq/sdk';
import { lowerCaseFirst } from 'lower-case-first';
import { CodeBlockWriter } from 'ts-morph';
import { name } from '..';
import { getIdFields, isAuthInvocation } from '../../../utils/ast-utils';
import {
TypeScriptExpressionTransformer,
TypeScriptExpressionTransformerError,
} from '../../../utils/typescript-expression-transformer';

type ComparisonOperator = '==' | '!=' | '>' | '>=' | '<' | '<=';

Expand Down
Loading