Skip to content

feat: support multi-id-field models (@@id([f1, f2, ...])) #243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "zenstack-monorepo",
"version": "1.0.0-alpha.56",
"version": "1.0.0-alpha.57",
"description": "",
"scripts": {
"build": "pnpm -r build",
Expand Down
2 changes: 1 addition & 1 deletion packages/language/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@zenstackhq/language",
"version": "1.0.0-alpha.56",
"version": "1.0.0-alpha.57",
"displayName": "ZenStack modeling language compiler",
"description": "ZenStack modeling language compiler",
"homepage": "https://zenstack.dev",
Expand Down
3 changes: 2 additions & 1 deletion packages/next/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@zenstackhq/next",
"version": "1.0.0-alpha.56",
"version": "1.0.0-alpha.57",
"displayName": "ZenStack Next.js integration",
"description": "ZenStack Next.js integration",
"homepage": "https://zenstack.dev",
Expand All @@ -9,6 +9,7 @@
"build": "pnpm lint && pnpm clean && tsc && copyfiles ./package.json ./README.md ./LICENSE dist",
"watch": "tsc --watch",
"lint": "eslint src --ext ts",
"test": "jest",
"prepublishOnly": "pnpm build",
"publish-dev": "pnpm publish --tag dev"
},
Expand Down
2 changes: 1 addition & 1 deletion packages/plugins/react/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/react",
"displayName": "ZenStack plugin and runtime for ReactJS",
"version": "1.0.0-alpha.56",
"version": "1.0.0-alpha.57",
"description": "ZenStack plugin and runtime for ReactJS",
"main": "index.js",
"repository": {
Expand Down
2 changes: 1 addition & 1 deletion packages/plugins/trpc/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/trpc",
"displayName": "ZenStack plugin for tRPC",
"version": "1.0.0-alpha.56",
"version": "1.0.0-alpha.57",
"description": "ZenStack plugin for tRPC",
"main": "index.js",
"repository": {
Expand Down
2 changes: 1 addition & 1 deletion packages/runtime/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@zenstackhq/runtime",
"displayName": "ZenStack Runtime Library",
"version": "1.0.0-alpha.56",
"version": "1.0.0-alpha.57",
"description": "Runtime of ZenStack for both client-side and server-side environments.",
"repository": {
"type": "git",
Expand Down
25 changes: 16 additions & 9 deletions packages/runtime/src/enhancements/policy/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,12 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
dbOps.create(writeArgs)
);

if (!this.utils.getEntityId(this.model, result)) {
const ids = this.utils.getEntityIds(this.model, result);
if (Object.keys(ids).length === 0) {
throw this.utils.unknownError(`unexpected error: create didn't return an id`);
}

return this.checkReadback(origArgs, this.utils.getEntityId(this.model, result), 'create', 'create');
return this.checkReadback(origArgs, ids, 'create', 'create');
}

async createMany(args: any, skipDuplicates?: boolean) {
Expand Down Expand Up @@ -136,10 +137,11 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
dbOps.update(writeArgs)
);

if (!this.utils.getEntityId(this.model, result)) {
const ids = this.utils.getEntityIds(this.model, result);
if (Object.keys(ids).length === 0) {
throw this.utils.unknownError(`unexpected error: update didn't return an id`);
}
return this.checkReadback(origArgs, this.utils.getEntityId(this.model, result), 'update', 'update');
return this.checkReadback(origArgs, ids, 'update', 'update');
}

async updateMany(args: any) {
Expand Down Expand Up @@ -189,11 +191,12 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
dbOps.upsert(writeArgs)
);

if (!this.utils.getEntityId(this.model, result)) {
const ids = this.utils.getEntityIds(this.model, result);
if (Object.keys(ids).length === 0) {
throw this.utils.unknownError(`unexpected error: upsert didn't return an id`);
}

return this.checkReadback(origArgs, this.utils.getEntityId(this.model, result), 'upsert', 'update');
return this.checkReadback(origArgs, ids, 'upsert', 'update');
}

async delete(args: any) {
Expand Down Expand Up @@ -283,9 +286,13 @@ export class PolicyProxyHandler<DbClient extends DbClientContract> implements Pr
}
}

private async checkReadback(origArgs: any, id: any, action: string, operation: PolicyOperationKind) {
const idField = this.utils.getIdField(this.model);
const readArgs = { select: origArgs.select, include: origArgs.include, where: { [idField.name]: id } };
private async checkReadback(
origArgs: any,
ids: Record<string, unknown>,
action: string,
operation: PolicyOperationKind
) {
const readArgs = { select: origArgs.select, include: origArgs.include, where: ids };
const result = await this.utils.readWithCheck(this.model, readArgs);
if (result.length === 0) {
this.logger.warn(`${action} result cannot be read back`);
Expand Down
102 changes: 65 additions & 37 deletions packages/runtime/src/enhancements/policy/policy-utils.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* eslint-disable @typescript-eslint/no-explicit-any */

import { PrismaClientKnownRequestError, PrismaClientUnknownRequestError } from '@prisma/client/runtime';
import { AUXILIARY_FIELDS, CrudFailureReason, TRANSACTION_FIELD_NAME } from '@zenstackhq/sdk';
import { AUXILIARY_FIELDS, CrudFailureReason, GUARD_FIELD_NAME, TRANSACTION_FIELD_NAME } from '@zenstackhq/sdk';
import { camelCase } from 'change-case';
import cuid from 'cuid';
import deepcopy from 'deepcopy';
Expand Down Expand Up @@ -42,8 +42,7 @@ export class PolicyUtil {
and(...conditions: (boolean | object)[]): any {
if (conditions.includes(false)) {
// always false
// TODO: custom id field
return { id: { in: [] } };
return { [GUARD_FIELD_NAME]: false };
}

const filtered = conditions.filter(
Expand All @@ -64,7 +63,7 @@ export class PolicyUtil {
or(...conditions: (boolean | object)[]): any {
if (conditions.includes(true)) {
// always true
return { id: { notIn: [] } };
return { [GUARD_FIELD_NAME]: true };
}

const filtered = conditions.filter((c): c is object => typeof c === 'object' && !!c);
Expand Down Expand Up @@ -276,7 +275,7 @@ export class PolicyUtil {
return;
}

const idField = this.getIdField(model);
const idFields = this.getIdFields(model);
for (const field of getModelFields(injectTarget)) {
const fieldInfo = resolveField(this.modelMeta, model, field);
if (!fieldInfo || !fieldInfo.isDataModel) {
Expand All @@ -292,10 +291,16 @@ export class PolicyUtil {

await this.injectAuthGuard(injectTarget[field], fieldInfo.type, 'read');
} else {
// there's no way of injecting condition for to-one relation, so we
// make sure 'id' field is selected and check them against query result
if (injectTarget[field]?.select && injectTarget[field]?.select?.[idField.name] !== true) {
injectTarget[field].select[idField.name] = true;
// there's no way of injecting condition for to-one relation, so if there's
// "select" clause we make sure 'id' fields are selected and check them against
// query result; nothing needs to be done for "include" clause because all
// fields are already selected
if (injectTarget[field]?.select) {
for (const idField of idFields) {
if (injectTarget[field].select[idField.name] !== true) {
injectTarget[field].select[idField.name] = true;
}
}
}
}

Expand All @@ -310,7 +315,8 @@ export class PolicyUtil {
* omitted.
*/
async postProcessForRead(entityData: any, model: string, args: any, operation: PolicyOperationKind) {
if (!this.getEntityId(model, entityData)) {
const ids = this.getEntityIds(model, entityData);
if (Object.keys(ids).length === 0) {
return;
}

Expand All @@ -330,21 +336,23 @@ export class PolicyUtil {
// post-check them

for (const field of getModelFields(injectTarget)) {
if (!entityData?.[field]) {
continue;
}

const fieldInfo = resolveField(this.modelMeta, model, field);
if (!fieldInfo || !fieldInfo.isDataModel || fieldInfo.isArray) {
continue;
}

const idField = this.getIdField(fieldInfo.type);
const relatedEntityId = entityData?.[field]?.[idField.name];
const ids = this.getEntityIds(fieldInfo.type, entityData[field]);

if (!relatedEntityId) {
if (Object.keys(ids).length === 0) {
continue;
}

this.logger.info(`Validating read of to-one relation: ${fieldInfo.type}#${relatedEntityId}`);

await this.checkPolicyForFilter(fieldInfo.type, { [idField.name]: relatedEntityId }, operation, this.db);
this.logger.info(`Validating read of to-one relation: ${fieldInfo.type}#${formatObject(ids)}`);
await this.checkPolicyForFilter(fieldInfo.type, ids, operation, this.db);

// recurse
await this.postProcessForRead(entityData[field], fieldInfo.type, injectTarget[field], operation);
Expand All @@ -366,14 +374,18 @@ export class PolicyUtil {

// record model entities that are updated, together with their
// values before update, so we can post-check if they satisfy
// model => id => entity value
const updatedModels = new Map<string, Map<string, any>>();
// model => { ids, entity value }
const updatedModels = new Map<string, Array<{ ids: Record<string, unknown>; value: any }>>();

const idField = this.getIdField(model);
if (args.select && !args.select[idField.name]) {
const idFields = this.getIdFields(model);
if (args.select) {
// make sure 'id' field is selected, we need it to
// read back the updated entity
args.select[idField.name] = true;
for (const idField of idFields) {
if (!args.select[idField.name]) {
args.select[idField.name] = true;
}
}
}

// use a transaction to conduct write, so in case any create or nested create
Expand Down Expand Up @@ -496,7 +508,7 @@ export class PolicyUtil {
if (postGuard !== true || schema) {
let modelEntities = updatedModels.get(model);
if (!modelEntities) {
modelEntities = new Map<string, any>();
modelEntities = [];
updatedModels.set(model, modelEntities);
}

Expand All @@ -509,11 +521,19 @@ export class PolicyUtil {
// e.g.: { a_b: { a: '1', b: '1' } } => { a: '1', b: '1' }
await this.flattenGeneratedUniqueField(model, filter);

const idField = this.getIdField(model);
const query = { where: filter, select: { ...preValueSelect, [idField.name]: true } };
const idFields = this.getIdFields(model);
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const select: any = { ...preValueSelect };
for (const idField of idFields) {
select[idField.name] = true;
}

const query = { where: filter, select };
this.logger.info(`fetching pre-update entities for ${model}: ${formatObject(query)})}`);
const entities = await this.db[model].findMany(query);
entities.forEach((entity) => modelEntities?.set(this.getEntityId(model, entity), entity));
entities.forEach((entity) =>
modelEntities?.push({ ids: this.getEntityIds(model, entity), value: entity })
);
}
};

Expand Down Expand Up @@ -622,8 +642,8 @@ export class PolicyUtil {
await Promise.all(
[...updatedModels.entries()]
.map(([model, modelEntities]) =>
[...modelEntities.entries()].map(async ([id, preValue]) =>
this.checkPostUpdate(model, id, tx, preValue)
modelEntities.map(async ({ ids, value: preValue }) =>
this.checkPostUpdate(model, ids, tx, preValue)
)
)
.flat()
Expand Down Expand Up @@ -716,14 +736,18 @@ export class PolicyUtil {
}
}

private async checkPostUpdate(model: string, id: any, db: Record<string, DbOperations>, preValue: any) {
this.logger.info(`Checking post-update policy for ${model}#${id}, preValue: ${formatObject(preValue)}`);
private async checkPostUpdate(
model: string,
ids: Record<string, unknown>,
db: Record<string, DbOperations>,
preValue: any
) {
this.logger.info(`Checking post-update policy for ${model}#${ids}, preValue: ${formatObject(preValue)}`);

const guard = await this.getAuthGuard(model, 'postUpdate', preValue);

// build a query condition with policy injected
const idField = this.getIdField(model);
const guardedQuery = { where: this.and({ [idField.name]: id }, guard) };
const guardedQuery = { where: this.and(ids, guard) };

// query with policy injected
const entity = await db[model].findFirst(guardedQuery);
Expand Down Expand Up @@ -760,13 +784,13 @@ export class PolicyUtil {
/**
* Gets "id" field for a given model.
*/
getIdField(model: string) {
getIdFields(model: string) {
const fields = this.modelMeta.fields[camelCase(model)];
if (!fields) {
throw this.unknownError(`Unable to load fields for ${model}`);
}
const result = Object.values(fields).find((f) => f.isId);
if (!result) {
const result = Object.values(fields).filter((f) => f.isId);
if (result.length === 0) {
throw this.unknownError(`model ${model} does not have an id field`);
}
return result;
Expand All @@ -775,8 +799,12 @@ export class PolicyUtil {
/**
* Gets id field value from an entity.
*/
getEntityId(model: string, entityData: any) {
const idField = this.getIdField(model);
return entityData[idField.name];
getEntityIds(model: string, entityData: any) {
const idFields = this.getIdFields(model);
const result: Record<string, unknown> = {};
for (const idField of idFields) {
result[idField.name] = entityData[idField.name];
}
return result;
}
}
2 changes: 1 addition & 1 deletion packages/schema/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"publisher": "zenstack",
"displayName": "ZenStack Language Tools",
"description": "A toolkit for building secure CRUD apps with Next.js + Typescript",
"version": "1.0.0-alpha.56",
"version": "1.0.0-alpha.57",
"author": {
"name": "ZenStack Team"
},
Expand Down
Loading