From d9efa9535dfa5ecb68d6e8ba3f722596bd30fb4d Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Mon, 6 Mar 2023 19:27:51 +0800 Subject: [PATCH] feat: support multi-id-field models (@@id([f1, f2, ...])) --- package.json | 2 +- packages/language/package.json | 2 +- packages/next/package.json | 3 +- packages/plugins/react/package.json | 2 +- packages/plugins/trpc/package.json | 2 +- packages/runtime/package.json | 2 +- .../src/enhancements/policy/handler.ts | 25 +++-- .../src/enhancements/policy/policy-utils.ts | 102 +++++++++++------- packages/schema/package.json | 2 +- packages/schema/src/language-server/utils.ts | 46 +++++++- .../validator/datamodel-validator.ts | 44 +++++--- .../schema/src/plugins/model-meta/index.ts | 18 +++- .../src/plugins/prisma/schema-generator.ts | 14 +-- packages/schema/src/res/starter.zmodel | 4 + .../validation/datamodel-validation.test.ts | 64 ++++++++++- packages/sdk/package.json | 2 +- packages/testtools/package.json | 2 +- tests/integration/test-run/package-lock.json | 4 +- .../tests/with-policy/multi-id-fields.test.ts | 71 ++++++++++++ 19 files changed, 326 insertions(+), 85 deletions(-) create mode 100644 tests/integration/tests/with-policy/multi-id-fields.test.ts diff --git a/package.json b/package.json index 45000fb8d..230040c20 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "zenstack-monorepo", - "version": "1.0.0-alpha.56", + "version": "1.0.0-alpha.57", "description": "", "scripts": { "build": "pnpm -r build", diff --git a/packages/language/package.json b/packages/language/package.json index 9255bc83e..44249884b 100644 --- a/packages/language/package.json +++ b/packages/language/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/language", - "version": "1.0.0-alpha.56", + "version": "1.0.0-alpha.57", "displayName": "ZenStack modeling language compiler", "description": "ZenStack modeling language compiler", "homepage": "https://zenstack.dev", diff --git a/packages/next/package.json b/packages/next/package.json index b77b58744..31536a73f 100644 --- a/packages/next/package.json +++ b/packages/next/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/next", - "version": "1.0.0-alpha.56", + "version": "1.0.0-alpha.57", "displayName": "ZenStack Next.js integration", "description": "ZenStack Next.js integration", "homepage": "https://zenstack.dev", @@ -9,6 +9,7 @@ "build": "pnpm lint && pnpm clean && tsc && copyfiles ./package.json ./README.md ./LICENSE dist", "watch": "tsc --watch", "lint": "eslint src --ext ts", + "test": "jest", "prepublishOnly": "pnpm build", "publish-dev": "pnpm publish --tag dev" }, diff --git a/packages/plugins/react/package.json b/packages/plugins/react/package.json index 258476b26..fc9e28759 100644 --- a/packages/plugins/react/package.json +++ b/packages/plugins/react/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/react", "displayName": "ZenStack plugin and runtime for ReactJS", - "version": "1.0.0-alpha.56", + "version": "1.0.0-alpha.57", "description": "ZenStack plugin and runtime for ReactJS", "main": "index.js", "repository": { diff --git a/packages/plugins/trpc/package.json b/packages/plugins/trpc/package.json index 7d652358f..b48eff4e0 100644 --- a/packages/plugins/trpc/package.json +++ b/packages/plugins/trpc/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/trpc", "displayName": "ZenStack plugin for tRPC", - "version": "1.0.0-alpha.56", + "version": "1.0.0-alpha.57", "description": "ZenStack plugin for tRPC", "main": "index.js", "repository": { diff --git a/packages/runtime/package.json b/packages/runtime/package.json index 9bd73738f..a1d445a37 100644 --- a/packages/runtime/package.json +++ b/packages/runtime/package.json @@ -1,7 +1,7 @@ { "name": "@zenstackhq/runtime", "displayName": "ZenStack Runtime Library", - "version": "1.0.0-alpha.56", + "version": "1.0.0-alpha.57", "description": "Runtime of ZenStack for both client-side and server-side environments.", "repository": { "type": "git", diff --git a/packages/runtime/src/enhancements/policy/handler.ts b/packages/runtime/src/enhancements/policy/handler.ts index 03a4e37f1..272cbbb70 100644 --- a/packages/runtime/src/enhancements/policy/handler.ts +++ b/packages/runtime/src/enhancements/policy/handler.ts @@ -86,11 +86,12 @@ export class PolicyProxyHandler implements Pr dbOps.create(writeArgs) ); - if (!this.utils.getEntityId(this.model, result)) { + const ids = this.utils.getEntityIds(this.model, result); + if (Object.keys(ids).length === 0) { throw this.utils.unknownError(`unexpected error: create didn't return an id`); } - return this.checkReadback(origArgs, this.utils.getEntityId(this.model, result), 'create', 'create'); + return this.checkReadback(origArgs, ids, 'create', 'create'); } async createMany(args: any, skipDuplicates?: boolean) { @@ -136,10 +137,11 @@ export class PolicyProxyHandler implements Pr dbOps.update(writeArgs) ); - if (!this.utils.getEntityId(this.model, result)) { + const ids = this.utils.getEntityIds(this.model, result); + if (Object.keys(ids).length === 0) { throw this.utils.unknownError(`unexpected error: update didn't return an id`); } - return this.checkReadback(origArgs, this.utils.getEntityId(this.model, result), 'update', 'update'); + return this.checkReadback(origArgs, ids, 'update', 'update'); } async updateMany(args: any) { @@ -189,11 +191,12 @@ export class PolicyProxyHandler implements Pr dbOps.upsert(writeArgs) ); - if (!this.utils.getEntityId(this.model, result)) { + const ids = this.utils.getEntityIds(this.model, result); + if (Object.keys(ids).length === 0) { throw this.utils.unknownError(`unexpected error: upsert didn't return an id`); } - return this.checkReadback(origArgs, this.utils.getEntityId(this.model, result), 'upsert', 'update'); + return this.checkReadback(origArgs, ids, 'upsert', 'update'); } async delete(args: any) { @@ -283,9 +286,13 @@ export class PolicyProxyHandler implements Pr } } - private async checkReadback(origArgs: any, id: any, action: string, operation: PolicyOperationKind) { - const idField = this.utils.getIdField(this.model); - const readArgs = { select: origArgs.select, include: origArgs.include, where: { [idField.name]: id } }; + private async checkReadback( + origArgs: any, + ids: Record, + action: string, + operation: PolicyOperationKind + ) { + const readArgs = { select: origArgs.select, include: origArgs.include, where: ids }; const result = await this.utils.readWithCheck(this.model, readArgs); if (result.length === 0) { this.logger.warn(`${action} result cannot be read back`); diff --git a/packages/runtime/src/enhancements/policy/policy-utils.ts b/packages/runtime/src/enhancements/policy/policy-utils.ts index 4c0ec1933..452ff71e8 100644 --- a/packages/runtime/src/enhancements/policy/policy-utils.ts +++ b/packages/runtime/src/enhancements/policy/policy-utils.ts @@ -1,7 +1,7 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ import { PrismaClientKnownRequestError, PrismaClientUnknownRequestError } from '@prisma/client/runtime'; -import { AUXILIARY_FIELDS, CrudFailureReason, TRANSACTION_FIELD_NAME } from '@zenstackhq/sdk'; +import { AUXILIARY_FIELDS, CrudFailureReason, GUARD_FIELD_NAME, TRANSACTION_FIELD_NAME } from '@zenstackhq/sdk'; import { camelCase } from 'change-case'; import cuid from 'cuid'; import deepcopy from 'deepcopy'; @@ -42,8 +42,7 @@ export class PolicyUtil { and(...conditions: (boolean | object)[]): any { if (conditions.includes(false)) { // always false - // TODO: custom id field - return { id: { in: [] } }; + return { [GUARD_FIELD_NAME]: false }; } const filtered = conditions.filter( @@ -64,7 +63,7 @@ export class PolicyUtil { or(...conditions: (boolean | object)[]): any { if (conditions.includes(true)) { // always true - return { id: { notIn: [] } }; + return { [GUARD_FIELD_NAME]: true }; } const filtered = conditions.filter((c): c is object => typeof c === 'object' && !!c); @@ -276,7 +275,7 @@ export class PolicyUtil { return; } - const idField = this.getIdField(model); + const idFields = this.getIdFields(model); for (const field of getModelFields(injectTarget)) { const fieldInfo = resolveField(this.modelMeta, model, field); if (!fieldInfo || !fieldInfo.isDataModel) { @@ -292,10 +291,16 @@ export class PolicyUtil { await this.injectAuthGuard(injectTarget[field], fieldInfo.type, 'read'); } else { - // there's no way of injecting condition for to-one relation, so we - // make sure 'id' field is selected and check them against query result - if (injectTarget[field]?.select && injectTarget[field]?.select?.[idField.name] !== true) { - injectTarget[field].select[idField.name] = true; + // there's no way of injecting condition for to-one relation, so if there's + // "select" clause we make sure 'id' fields are selected and check them against + // query result; nothing needs to be done for "include" clause because all + // fields are already selected + if (injectTarget[field]?.select) { + for (const idField of idFields) { + if (injectTarget[field].select[idField.name] !== true) { + injectTarget[field].select[idField.name] = true; + } + } } } @@ -310,7 +315,8 @@ export class PolicyUtil { * omitted. */ async postProcessForRead(entityData: any, model: string, args: any, operation: PolicyOperationKind) { - if (!this.getEntityId(model, entityData)) { + const ids = this.getEntityIds(model, entityData); + if (Object.keys(ids).length === 0) { return; } @@ -330,21 +336,23 @@ export class PolicyUtil { // post-check them for (const field of getModelFields(injectTarget)) { + if (!entityData?.[field]) { + continue; + } + const fieldInfo = resolveField(this.modelMeta, model, field); if (!fieldInfo || !fieldInfo.isDataModel || fieldInfo.isArray) { continue; } - const idField = this.getIdField(fieldInfo.type); - const relatedEntityId = entityData?.[field]?.[idField.name]; + const ids = this.getEntityIds(fieldInfo.type, entityData[field]); - if (!relatedEntityId) { + if (Object.keys(ids).length === 0) { continue; } - this.logger.info(`Validating read of to-one relation: ${fieldInfo.type}#${relatedEntityId}`); - - await this.checkPolicyForFilter(fieldInfo.type, { [idField.name]: relatedEntityId }, operation, this.db); + this.logger.info(`Validating read of to-one relation: ${fieldInfo.type}#${formatObject(ids)}`); + await this.checkPolicyForFilter(fieldInfo.type, ids, operation, this.db); // recurse await this.postProcessForRead(entityData[field], fieldInfo.type, injectTarget[field], operation); @@ -366,14 +374,18 @@ export class PolicyUtil { // record model entities that are updated, together with their // values before update, so we can post-check if they satisfy - // model => id => entity value - const updatedModels = new Map>(); + // model => { ids, entity value } + const updatedModels = new Map; value: any }>>(); - const idField = this.getIdField(model); - if (args.select && !args.select[idField.name]) { + const idFields = this.getIdFields(model); + if (args.select) { // make sure 'id' field is selected, we need it to // read back the updated entity - args.select[idField.name] = true; + for (const idField of idFields) { + if (!args.select[idField.name]) { + args.select[idField.name] = true; + } + } } // use a transaction to conduct write, so in case any create or nested create @@ -496,7 +508,7 @@ export class PolicyUtil { if (postGuard !== true || schema) { let modelEntities = updatedModels.get(model); if (!modelEntities) { - modelEntities = new Map(); + modelEntities = []; updatedModels.set(model, modelEntities); } @@ -509,11 +521,19 @@ export class PolicyUtil { // e.g.: { a_b: { a: '1', b: '1' } } => { a: '1', b: '1' } await this.flattenGeneratedUniqueField(model, filter); - const idField = this.getIdField(model); - const query = { where: filter, select: { ...preValueSelect, [idField.name]: true } }; + const idFields = this.getIdFields(model); + // eslint-disable-next-line @typescript-eslint/no-unused-vars + const select: any = { ...preValueSelect }; + for (const idField of idFields) { + select[idField.name] = true; + } + + const query = { where: filter, select }; this.logger.info(`fetching pre-update entities for ${model}: ${formatObject(query)})}`); const entities = await this.db[model].findMany(query); - entities.forEach((entity) => modelEntities?.set(this.getEntityId(model, entity), entity)); + entities.forEach((entity) => + modelEntities?.push({ ids: this.getEntityIds(model, entity), value: entity }) + ); } }; @@ -622,8 +642,8 @@ export class PolicyUtil { await Promise.all( [...updatedModels.entries()] .map(([model, modelEntities]) => - [...modelEntities.entries()].map(async ([id, preValue]) => - this.checkPostUpdate(model, id, tx, preValue) + modelEntities.map(async ({ ids, value: preValue }) => + this.checkPostUpdate(model, ids, tx, preValue) ) ) .flat() @@ -716,14 +736,18 @@ export class PolicyUtil { } } - private async checkPostUpdate(model: string, id: any, db: Record, preValue: any) { - this.logger.info(`Checking post-update policy for ${model}#${id}, preValue: ${formatObject(preValue)}`); + private async checkPostUpdate( + model: string, + ids: Record, + db: Record, + preValue: any + ) { + this.logger.info(`Checking post-update policy for ${model}#${ids}, preValue: ${formatObject(preValue)}`); const guard = await this.getAuthGuard(model, 'postUpdate', preValue); // build a query condition with policy injected - const idField = this.getIdField(model); - const guardedQuery = { where: this.and({ [idField.name]: id }, guard) }; + const guardedQuery = { where: this.and(ids, guard) }; // query with policy injected const entity = await db[model].findFirst(guardedQuery); @@ -760,13 +784,13 @@ export class PolicyUtil { /** * Gets "id" field for a given model. */ - getIdField(model: string) { + getIdFields(model: string) { const fields = this.modelMeta.fields[camelCase(model)]; if (!fields) { throw this.unknownError(`Unable to load fields for ${model}`); } - const result = Object.values(fields).find((f) => f.isId); - if (!result) { + const result = Object.values(fields).filter((f) => f.isId); + if (result.length === 0) { throw this.unknownError(`model ${model} does not have an id field`); } return result; @@ -775,8 +799,12 @@ export class PolicyUtil { /** * Gets id field value from an entity. */ - getEntityId(model: string, entityData: any) { - const idField = this.getIdField(model); - return entityData[idField.name]; + getEntityIds(model: string, entityData: any) { + const idFields = this.getIdFields(model); + const result: Record = {}; + for (const idField of idFields) { + result[idField.name] = entityData[idField.name]; + } + return result; } } diff --git a/packages/schema/package.json b/packages/schema/package.json index 6fa920912..f864b7251 100644 --- a/packages/schema/package.json +++ b/packages/schema/package.json @@ -3,7 +3,7 @@ "publisher": "zenstack", "displayName": "ZenStack Language Tools", "description": "A toolkit for building secure CRUD apps with Next.js + Typescript", - "version": "1.0.0-alpha.56", + "version": "1.0.0-alpha.57", "author": { "name": "ZenStack Team" }, diff --git a/packages/schema/src/language-server/utils.ts b/packages/schema/src/language-server/utils.ts index 261c4815e..dcf6a9f37 100644 --- a/packages/schema/src/language-server/utils.ts +++ b/packages/schema/src/language-server/utils.ts @@ -1,6 +1,15 @@ import { AstNode } from 'langium'; import { STD_LIB_MODULE_NAME } from './constants'; -import { isModel, Model } from '@zenstackhq/language/ast'; +import { + DataModel, + DataModelField, + isArrayExpr, + isModel, + isReferenceExpr, + Model, + ReferenceExpr, +} from '@zenstackhq/language/ast'; +import { resolved } from '@zenstackhq/sdk'; /** * Gets the toplevel Model containing the given node. @@ -19,3 +28,38 @@ export function isFromStdlib(node: AstNode) { const model = getContainingModel(node); return !!model && !!model.$document && model.$document.uri.path.endsWith(STD_LIB_MODULE_NAME); } + +/** + * Gets id fields declared at the data model level + */ +export function getIdFields(model: DataModel) { + const idAttr = model.attributes.find((attr) => attr.decl.ref?.name === '@@id'); + if (!idAttr) { + return []; + } + const fieldsArg = idAttr.args.find((a) => a.$resolvedParam?.name === 'fields'); + if (!fieldsArg || !isArrayExpr(fieldsArg.value)) { + return []; + } + + return fieldsArg.value.items + .filter((item): item is ReferenceExpr => isReferenceExpr(item)) + .map((item) => resolved(item.target) as DataModelField); +} + +/** + * Gets lists of unique fields declared at the data model level + */ +export function getUniqueFields(model: DataModel) { + const uniqueAttrs = model.attributes.filter((attr) => attr.decl.ref?.name === '@@unique'); + return uniqueAttrs.map((uniqueAttr) => { + const fieldsArg = uniqueAttr.args.find((a) => a.$resolvedParam?.name === 'fields'); + if (!fieldsArg || !isArrayExpr(fieldsArg.value)) { + return []; + } + + return fieldsArg.value.items + .filter((item): item is ReferenceExpr => isReferenceExpr(item)) + .map((item) => resolved(item.target) as DataModelField); + }); +} diff --git a/packages/schema/src/language-server/validator/datamodel-validator.ts b/packages/schema/src/language-server/validator/datamodel-validator.ts index 9a645d484..873f7935d 100644 --- a/packages/schema/src/language-server/validator/datamodel-validator.ts +++ b/packages/schema/src/language-server/validator/datamodel-validator.ts @@ -10,6 +10,7 @@ import { ValidationAcceptor } from 'langium'; import { analyzePolicies } from '../../utils/ast-utils'; import { IssueCodes, SCALAR_TYPES } from '../constants'; import { AstValidator } from '../types'; +import { getIdFields, getUniqueFields } from '../utils'; import { validateAttributeApplication, validateDuplicatedDeclarations } from './utils'; /** @@ -18,35 +19,41 @@ import { validateAttributeApplication, validateDuplicatedDeclarations } from './ export default class DataModelValidator implements AstValidator { validate(dm: DataModel, accept: ValidationAcceptor): void { validateDuplicatedDeclarations(dm.fields, accept); - this.validateFields(dm, accept); this.validateAttributes(dm, accept); + this.validateFields(dm, accept); } private validateFields(dm: DataModel, accept: ValidationAcceptor) { - // TODO: check conflict of @id and @@id - const idFields = dm.fields.filter((f) => f.attributes.find((attr) => attr.decl.ref?.name === '@id')); - if (idFields.length === 0) { + const modelLevelIds = getIdFields(dm); + + if (idFields.length === 0 && modelLevelIds.length === 0) { const { allows, denies, hasFieldValidation } = analyzePolicies(dm); if (allows.length > 0 || denies.length > 0 || hasFieldValidation) { // TODO: relax this requirement to require only @unique fields // when access policies or field valdaition is used, require an @id field - accept('error', 'Model must include a field with @id attribute', { + accept('error', 'Model must include a field with @id attribute or a model-level @@id attribute', { node: dm, }); } + } else if (idFields.length > 0 && modelLevelIds.length > 0) { + accept('error', 'Model cannot have both field-level @id and model-level @@id attributes', { + node: dm, + }); } else if (idFields.length > 1) { accept('error', 'Model can include at most one field with @id attribute', { node: dm, }); } else { - if (idFields[0].type.optional) { - accept('error', 'Field with @id attribute must not be optional', { node: idFields[0] }); - } - - if (idFields[0].type.array || !idFields[0].type.type || !SCALAR_TYPES.includes(idFields[0].type.type)) { - accept('error', 'Field with @id attribute must be of scalar type', { node: idFields[0] }); - } + const fieldsToCheck = idFields.length > 0 ? idFields : modelLevelIds; + fieldsToCheck.forEach((idField) => { + if (idField.type.optional) { + accept('error', 'Field with @id attribute must not be optional', { node: idField }); + } + if (idField.type.array || !idField.type.type || !SCALAR_TYPES.includes(idField.type.type)) { + accept('error', 'Field with @id attribute must be of scalar type', { node: idField }); + } + }); } dm.fields.forEach((field) => this.validateField(field, accept)); @@ -241,12 +248,21 @@ export default class DataModelValidator implements AstValidator { // // UserData.userId field needs to be @unique + const containingModel = field.$container as DataModel; + const uniqueFieldList = getUniqueFields(containingModel); + thisRelation.fields?.forEach((ref) => { const refField = ref.target.ref as DataModelField; - if (refField && !refField.attributes.find((a) => a.decl.ref?.name === '@unique')) { + if (refField) { + if (refField.attributes.find((a) => a.decl.ref?.name === '@unique')) { + return; + } + if (uniqueFieldList.some((list) => list.includes(refField))) { + return; + } accept( 'error', - `Field "${refField.name}" is part of a one-to-one relation and must be marked as @unique`, + `Field "${refField.name}" is part of a one-to-one relation and must be marked as @unique or be part of a model-level @@unique attribute`, { node: refField } ); } diff --git a/packages/schema/src/plugins/model-meta/index.ts b/packages/schema/src/plugins/model-meta/index.ts index 13794805a..43d6a255a 100644 --- a/packages/schema/src/plugins/model-meta/index.ts +++ b/packages/schema/src/plugins/model-meta/index.ts @@ -12,6 +12,7 @@ import { getAttributeArgs, getLiteral, PluginOptions, resolved } from '@zenstack import { camelCase } from 'change-case'; import path from 'path'; import { CodeBlockWriter, Project, VariableDeclarationKind } from 'ts-morph'; +import { getIdFields } from '../../language-server/utils'; import { ensureNodeModuleFolder, getDefaultOutputFolder } from '../plugin-utils'; export const name = 'Model Metadata'; @@ -142,12 +143,25 @@ function getFieldAttributes(field: DataModelField): RuntimeAttribute[] { } function isIdField(field: DataModelField) { - return field.attributes.some((attr) => attr.decl.ref?.name === '@id'); + // field-level @id attribute + if (field.attributes.some((attr) => attr.decl.ref?.name === '@id')) { + return true; + } + + // model-level @@id attribute with a list of fields + const model = field.$container as DataModel; + const modelLevelIds = getIdFields(model); + if (modelLevelIds.includes(field)) { + return true; + } + return false; } function getUniqueConstraints(model: DataModel) { const constraints: Array<{ name: string; fields: string[] }> = []; - for (const attr of model.attributes.filter((attr) => attr.decl.ref?.name === '@@unique')) { + for (const attr of model.attributes.filter( + (attr) => attr.decl.ref?.name === '@@unique' || attr.decl.ref?.name === '@@id' + )) { const argsMap = getAttributeArgs(attr); if (argsMap.fields) { const fieldNames = (argsMap.fields as ArrayExpr).items.map( diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 0be856429..cf1becad0 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -173,18 +173,18 @@ export default class PrismaSchemaGenerator { this.generateModelField(model, field); } + // add an "zenstack_guard" field for dealing with pure auth() related conditions + model.addField(GUARD_FIELD_NAME, 'Boolean', [ + new PrismaFieldAttribute('@default', [ + new PrismaAttributeArg(undefined, new PrismaAttributeArgValue('Boolean', true)), + ]), + ]); + const { allowAll, denyAll, hasFieldValidation } = analyzePolicies(decl); if ((!allowAll && !denyAll) || hasFieldValidation) { // generate auxiliary fields for policy check - // add an "zenstack_guard" field for dealing with pure auth() related conditions - model.addField(GUARD_FIELD_NAME, 'Boolean', [ - new PrismaFieldAttribute('@default', [ - new PrismaAttributeArg(undefined, new PrismaAttributeArgValue('Boolean', true)), - ]), - ]); - // add an "zenstack_transaction" field for tracking records created/updated with nested writes model.addField(TRANSACTION_FIELD_NAME, 'String?'); diff --git a/packages/schema/src/res/starter.zmodel b/packages/schema/src/res/starter.zmodel index 6cfee5b0f..75207fc1e 100644 --- a/packages/schema/src/res/starter.zmodel +++ b/packages/schema/src/res/starter.zmodel @@ -10,6 +10,10 @@ datasource db { url = 'file:./todo.db' } +generator client { + provider = "prisma-client-js" +} + /* * User model */ diff --git a/packages/schema/tests/schema/validation/datamodel-validation.test.ts b/packages/schema/tests/schema/validation/datamodel-validation.test.ts index fbe7ca4e9..067f5b341 100644 --- a/packages/schema/tests/schema/validation/datamodel-validation.test.ts +++ b/packages/schema/tests/schema/validation/datamodel-validation.test.ts @@ -101,7 +101,7 @@ describe('Data Model Validation Tests', () => { @@allow('all', x > 0) } `) - ).toContain(`Model must include a field with @id attribute`); + ).toContain(`Model must include a field with @id attribute or a model-level @@id attribute`); expect( await loadModelWithError(` @@ -111,7 +111,7 @@ describe('Data Model Validation Tests', () => { @@deny('all', x <= 0) } `) - ).toContain(`Model must include a field with @id attribute`); + ).toContain(`Model must include a field with @id attribute or a model-level @@id attribute`); expect( await loadModelWithError(` @@ -120,7 +120,7 @@ describe('Data Model Validation Tests', () => { x Int @gt(0) } `) - ).toContain(`Model must include a field with @id attribute`); + ).toContain(`Model must include a field with @id attribute or a model-level @@id attribute`); expect( await loadModelWithError(` @@ -132,6 +132,17 @@ describe('Data Model Validation Tests', () => { `) ).toContain(`Model can include at most one field with @id attribute`); + expect( + await loadModelWithError(` + ${prelude} + model M { + x Int @id + y Int + @@id([x, y]) + } + `) + ).toContain(`Model cannot have both field-level @id and model-level @@id attributes`); + expect( await loadModelWithError(` ${prelude} @@ -141,6 +152,16 @@ describe('Data Model Validation Tests', () => { `) ).toContain(`Field with @id attribute must not be optional`); + expect( + await loadModelWithError(` + ${prelude} + model M { + x Int? + @@id([x]) + } + `) + ).toContain(`Field with @id attribute must not be optional`); + expect( await loadModelWithError(` ${prelude} @@ -150,6 +171,16 @@ describe('Data Model Validation Tests', () => { `) ).toContain(`Field with @id attribute must be of scalar type`); + expect( + await loadModelWithError(` + ${prelude} + model M { + x Int[] + @@id([x]) + } + `) + ).toContain(`Field with @id attribute must be of scalar type`); + expect( await loadModelWithError(` ${prelude} @@ -159,6 +190,16 @@ describe('Data Model Validation Tests', () => { `) ).toContain(`Field with @id attribute must be of scalar type`); + expect( + await loadModelWithError(` + ${prelude} + model M { + x Json + @@id([x]) + } + `) + ).toContain(`Field with @id attribute must be of scalar type`); + expect( await loadModelWithError(` ${prelude} @@ -170,6 +211,19 @@ describe('Data Model Validation Tests', () => { } `) ).toContain(`Field with @id attribute must be of scalar type`); + + expect( + await loadModelWithError(` + ${prelude} + model Id { + id String @id + } + model M { + myId Id + @@id([myId]) + } + `) + ).toContain(`Field with @id attribute must be of scalar type`); }); it('relation', async () => { @@ -318,7 +372,9 @@ describe('Data Model Validation Tests', () => { aId String } `) - ).toContain(`Field "aId" is part of a one-to-one relation and must be marked as @unique`); + ).toContain( + `Field "aId" is part of a one-to-one relation and must be marked as @unique or be part of a model-level @@unique attribute` + ); // missing @relation expect( diff --git a/packages/sdk/package.json b/packages/sdk/package.json index aedc1e296..7f4a728c6 100644 --- a/packages/sdk/package.json +++ b/packages/sdk/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/sdk", - "version": "1.0.0-alpha.56", + "version": "1.0.0-alpha.57", "description": "ZenStack plugin development SDK", "main": "index.js", "scripts": { diff --git a/packages/testtools/package.json b/packages/testtools/package.json index 5596475fd..3023611b0 100644 --- a/packages/testtools/package.json +++ b/packages/testtools/package.json @@ -1,6 +1,6 @@ { "name": "@zenstackhq/testtools", - "version": "1.0.0-alpha.56", + "version": "1.0.0-alpha.57", "description": "ZenStack Test Tools", "main": "index.js", "publishConfig": { diff --git a/tests/integration/test-run/package-lock.json b/tests/integration/test-run/package-lock.json index d42fea66a..4c7ae0716 100644 --- a/tests/integration/test-run/package-lock.json +++ b/tests/integration/test-run/package-lock.json @@ -126,7 +126,7 @@ }, "../../../packages/runtime/dist": { "name": "@zenstackhq/runtime", - "version": "1.0.0-alpha.56", + "version": "1.0.0-alpha.57", "license": "MIT", "dependencies": { "@types/bcryptjs": "^2.4.2", @@ -156,7 +156,7 @@ }, "../../../packages/schema/dist": { "name": "zenstack", - "version": "1.0.8", + "version": "1.0.0-alpha.57", "hasInstallScript": true, "license": "MIT", "dependencies": { diff --git a/tests/integration/tests/with-policy/multi-id-fields.test.ts b/tests/integration/tests/with-policy/multi-id-fields.test.ts new file mode 100644 index 000000000..f9984f98f --- /dev/null +++ b/tests/integration/tests/with-policy/multi-id-fields.test.ts @@ -0,0 +1,71 @@ +import { loadSchema } from '@zenstackhq/testtools'; +import path from 'path'; + +describe('With Policy: multiple id fields', () => { + let origDir: string; + + beforeAll(async () => { + origDir = path.resolve('.'); + }); + + afterEach(() => { + process.chdir(origDir); + }); + + it('multi-id fields', async () => { + const { prisma, withPolicy } = await loadSchema( + ` + model A { + x String + y Int + value Int + b B? + @@id([x, y]) + + @@allow('read', true) + @@allow('create', value > 0) + } + + model B { + b1 String + b2 String + value Int + a A @relation(fields: [ax, ay], references: [x, y]) + ax String + ay Int + + @@allow('read', value > 2) + @@allow('create', value > 1) + + @@unique([ax, ay]) + @@id([b1, b2]) + } + ` + ); + + const db = withPolicy(); + + await expect(db.a.create({ data: { x: '1', y: 1, value: 0 } })).toBeRejectedByPolicy(); + await expect(db.a.create({ data: { x: '1', y: 2, value: 1 } })).toResolveTruthy(); + + await expect( + db.a.create({ data: { x: '2', y: 1, value: 1, b: { create: { b1: '1', b2: '2', value: 1 } } } }) + ).toBeRejectedByPolicy(); + + await expect( + db.a.create({ + include: { b: true }, + data: { x: '2', y: 1, value: 1, b: { create: { b1: '1', b2: '2', value: 2 } } }, + }) + ).toBeRejectedByPolicy(); + const r = await prisma.b.findUnique({ where: { b1_b2: { b1: '1', b2: '2' } } }); + expect(r.value).toBe(2); + + await expect( + db.a.create({ + include: { b: true }, + data: { x: '3', y: 1, value: 1, b: { create: { b1: '2', b2: '2', value: 3 } } }, + }) + ).toResolveTruthy(); + }); +});