diff --git a/packages/schema/package.json b/packages/schema/package.json index 32914eafa..68b51de1e 100644 --- a/packages/schema/package.json +++ b/packages/schema/package.json @@ -92,6 +92,7 @@ "colors": "1.4.0", "commander": "^8.3.0", "cuid": "^2.1.8", + "get-latest-version": "^5.0.1", "langium": "1.1.0", "mixpanel": "^0.17.0", "node-machine-id": "^1.1.12", @@ -108,7 +109,7 @@ "vscode-languageserver-textdocument": "^1.0.7", "vscode-uri": "^3.0.6", "zod": "^3.19.1", - "get-latest-version": "^5.0.1" + "zod-validation-error": "^0.2.1" }, "devDependencies": { "@types/async-exit-hook": "^2.0.0", diff --git a/packages/schema/src/cli/config.ts b/packages/schema/src/cli/config.ts new file mode 100644 index 000000000..a9ac8a2b3 --- /dev/null +++ b/packages/schema/src/cli/config.ts @@ -0,0 +1,40 @@ +import { GUARD_FIELD_NAME, TRANSACTION_FIELD_NAME } from '@zenstackhq/sdk'; +import fs from 'fs'; +import z from 'zod'; +import { fromZodError } from 'zod-validation-error'; +import { CliError } from './cli-error'; + +const schema = z + .object({ + guardFieldName: z.string().default(GUARD_FIELD_NAME), + transactionFieldName: z.string().default(TRANSACTION_FIELD_NAME), + }) + .strict(); + +export type ConfigType = z.infer; + +export let config: ConfigType = schema.parse({}); + +/** + * Loads and validates CLI configuration file. + * @returns + */ +export function loadConfig(filename: string) { + if (!fs.existsSync(filename)) { + return; + } + + let content: unknown; + try { + content = JSON.parse(fs.readFileSync(filename, 'utf-8')); + } catch { + throw new CliError(`Config is not a valid JSON file: ${filename}`); + } + + const parsed = schema.safeParse(content); + if (!parsed.success) { + throw new CliError(`Config file ${filename} is not valid: ${fromZodError(parsed.error)}`); + } + + config = parsed.data; +} diff --git a/packages/schema/src/cli/index.ts b/packages/schema/src/cli/index.ts index db5ee0939..f9f402d16 100644 --- a/packages/schema/src/cli/index.ts +++ b/packages/schema/src/cli/index.ts @@ -2,16 +2,20 @@ import { ZModelLanguageMetaData } from '@zenstackhq/language/module'; import colors from 'colors'; import { Command, Option } from 'commander'; +import fs from 'fs'; import * as semver from 'semver'; import telemetry from '../telemetry'; import { PackageManagers } from '../utils/pkg-utils'; import { getVersion } from '../utils/version-utils'; import { CliError } from './cli-error'; import { dumpInfo, initProject, runPlugins } from './cli-util'; +import { loadConfig } from './config'; // required minimal version of Prisma export const requiredPrismaVersion = '4.0.0'; +const DEFAULT_CONFIG_FILE = 'zenstack.config.json'; + export const initAction = async ( projectPath: string, options: { @@ -97,6 +101,8 @@ export function createProgram() { './schema.zmodel' ); + const configOption = new Option('-c, --config [file]', 'config file'); + const pmOption = new Option('-p, --package-manager ', 'package manager to use').choices([ 'npm', 'yarn', @@ -114,6 +120,7 @@ export function createProgram() { program .command('init') .description('Initialize an existing project for ZenStack.') + .addOption(configOption) .addOption(pmOption) .addOption(new Option('--prisma ', 'location of Prisma schema file to bootstrap from')) .addOption(new Option('--tag [tag]', 'the NPM package tag to use when installing dependencies')) @@ -124,9 +131,27 @@ export function createProgram() { .command('generate') .description('Run code generation.') .addOption(schemaOption) + .addOption(configOption) .addOption(pmOption) .addOption(noDependencyCheck) .action(generateAction); + + // make sure config is loaded before actions run + program.hook('preAction', async (_, actionCommand) => { + let configFile: string | undefined = actionCommand.opts().config; + if (!configFile && fs.existsSync(DEFAULT_CONFIG_FILE)) { + configFile = DEFAULT_CONFIG_FILE; + } + + if (configFile) { + if (fs.existsSync(configFile)) { + loadConfig(configFile); + } else { + throw new CliError(`Config file ${configFile} not found`); + } + } + }); + return program; } diff --git a/packages/schema/src/cli/plugin-runner.ts b/packages/schema/src/cli/plugin-runner.ts index 2800855dd..3cc375216 100644 --- a/packages/schema/src/cli/plugin-runner.ts +++ b/packages/schema/src/cli/plugin-runner.ts @@ -1,5 +1,5 @@ /* eslint-disable @typescript-eslint/no-var-requires */ -import { DMMF } from '@prisma/generator-helper'; +import type { DMMF } from '@prisma/generator-helper'; import { getDMMF } from '@prisma/internals'; import { isPlugin, Plugin } from '@zenstackhq/language/ast'; import { getLiteral, getLiteralArray, PluginError, PluginFunction, PluginOptions } from '@zenstackhq/sdk'; @@ -8,7 +8,8 @@ import fs from 'fs'; import ora from 'ora'; import path from 'path'; import telemetry from '../telemetry'; -import { Context } from '../types'; +import type { Context } from '../types'; +import { config } from './config'; /** * ZenStack code generator @@ -133,7 +134,7 @@ export class PluginRunner { plugin: name, }, async () => { - let result = run(context.schema, options, dmmf); + let result = run(context.schema, options, dmmf, config); if (result instanceof Promise) { result = await result; } diff --git a/packages/schema/src/plugins/prisma/index.ts b/packages/schema/src/plugins/prisma/index.ts index c9bb7ac08..6ffc0a6a8 100644 --- a/packages/schema/src/plugins/prisma/index.ts +++ b/packages/schema/src/plugins/prisma/index.ts @@ -1,9 +1,15 @@ +import type { DMMF } from '@prisma/generator-helper'; import { Model } from '@zenstackhq/language/ast'; import { PluginOptions } from '@zenstackhq/sdk'; import PrismaSchemaGenerator from './schema-generator'; export const name = 'Prisma'; -export default async function run(model: Model, options: PluginOptions) { - return new PrismaSchemaGenerator().generate(model, options); +export default async function run( + model: Model, + options: PluginOptions, + _dmmf?: DMMF.Document, + config?: Record +) { + return new PrismaSchemaGenerator().generate(model, options, config); } diff --git a/packages/schema/src/plugins/prisma/schema-generator.ts b/packages/schema/src/plugins/prisma/schema-generator.ts index 4066270e3..6d2c32e9b 100644 --- a/packages/schema/src/plugins/prisma/schema-generator.ts +++ b/packages/schema/src/plugins/prisma/schema-generator.ts @@ -69,7 +69,7 @@ export default class PrismaSchemaGenerator { `; - async generate(model: Model, options: PluginOptions) { + async generate(model: Model, options: PluginOptions, config?: Record) { const prisma = new PrismaModel(); for (const decl of model.declarations) { @@ -83,7 +83,7 @@ export default class PrismaSchemaGenerator { break; case DataModel: - this.generateModel(prisma, decl as DataModel); + this.generateModel(prisma, decl as DataModel, config); break; case GeneratorDecl: @@ -191,26 +191,33 @@ export default class PrismaSchemaGenerator { ); } - private generateModel(prisma: PrismaModel, decl: DataModel) { + private generateModel(prisma: PrismaModel, decl: DataModel, config?: Record) { const model = prisma.addModel(decl.name); for (const field of decl.fields) { this.generateModelField(model, field); } // add an "zenstack_guard" field for dealing with boolean conditions - model.addField(GUARD_FIELD_NAME, 'Boolean', [ + const guardField = model.addField(GUARD_FIELD_NAME, 'Boolean', [ new PrismaFieldAttribute('@default', [ new PrismaAttributeArg(undefined, new PrismaAttributeArgValue('Boolean', true)), ]), ]); + if (config?.guardFieldName && config?.guardFieldName !== GUARD_FIELD_NAME) { + // generate a @map to rename field in the database + guardField.addAttribute('@map', [ + new PrismaAttributeArg(undefined, new PrismaAttributeArgValue('String', config.guardFieldName)), + ]); + } + const { allowAll, denyAll, hasFieldValidation } = analyzePolicies(decl); if ((!allowAll && !denyAll) || hasFieldValidation) { // generate auxiliary fields for policy check // add an "zenstack_transaction" field for tracking records created/updated with nested writes - model.addField(TRANSACTION_FIELD_NAME, 'String?'); + const transactionField = model.addField(TRANSACTION_FIELD_NAME, 'String?'); // create an index for "zenstack_transaction" field model.addAttribute('@@index', [ @@ -221,6 +228,16 @@ export default class PrismaSchemaGenerator { ]) ), ]); + + if (config?.transactionFieldName && config?.transactionFieldName !== TRANSACTION_FIELD_NAME) { + // generate a @map to rename field in the database + transactionField.addAttribute('@map', [ + new PrismaAttributeArg( + undefined, + new PrismaAttributeArgValue('String', config.transactionFieldName) + ), + ]); + } } for (const attr of decl.attributes.filter((attr) => this.isPrismaAttribute(attr))) { diff --git a/packages/schema/tests/cli/cli.test.ts b/packages/schema/tests/cli/command.test.ts similarity index 88% rename from packages/schema/tests/cli/cli.test.ts rename to packages/schema/tests/cli/command.test.ts index f165c0803..73f735976 100644 --- a/packages/schema/tests/cli/cli.test.ts +++ b/packages/schema/tests/cli/command.test.ts @@ -8,7 +8,7 @@ import * as tmp from 'tmp'; import { createProgram } from '../../src/cli'; import { execSync } from '../../src/utils/exec-utils'; -describe('CLI Tests', () => { +describe('CLI Command Tests', () => { let projDir: string; let origDir: string; @@ -37,7 +37,7 @@ describe('CLI Tests', () => { createNpmrc(); const program = createProgram(); - program.parse(['init', '--tag', 'latest'], { from: 'user' }); + await program.parseAsync(['init', '--tag', 'latest'], { from: 'user' }); expect(fs.readFileSync('schema.zmodel', 'utf-8')).toEqual(fs.readFileSync('prisma/schema.prisma', 'utf-8')); @@ -53,7 +53,7 @@ describe('CLI Tests', () => { createNpmrc(); const program = createProgram(); - program.parse(['init', '--tag', 'latest'], { from: 'user' }); + await program.parseAsync(['init', '--tag', 'latest'], { from: 'user' }); expect(fs.readFileSync('schema.zmodel', 'utf-8')).toEqual(fs.readFileSync('prisma/schema.prisma', 'utf-8')); @@ -69,7 +69,7 @@ describe('CLI Tests', () => { createNpmrc(); const program = createProgram(); - program.parse(['init', '--tag', 'latest'], { from: 'user' }); + await program.parseAsync(['init', '--tag', 'latest'], { from: 'user' }); expect(fs.readFileSync('schema.zmodel', 'utf-8')).toEqual(fs.readFileSync('prisma/schema.prisma', 'utf-8')); @@ -86,7 +86,7 @@ describe('CLI Tests', () => { fs.renameSync('prisma/schema.prisma', 'prisma/my.prisma'); const program = createProgram(); - program.parse(['init', '--tag', 'latest', '--prisma', 'prisma/my.prisma'], { from: 'user' }); + await program.parseAsync(['init', '--tag', 'latest', '--prisma', 'prisma/my.prisma'], { from: 'user' }); expect(fs.readFileSync('schema.zmodel', 'utf-8')).toEqual(fs.readFileSync('prisma/my.prisma', 'utf-8')); }); @@ -96,7 +96,7 @@ describe('CLI Tests', () => { fs.writeFileSync('package.json', JSON.stringify({ name: 'my app', version: '1.0.0' })); createNpmrc(); const program = createProgram(); - program.parse(['init', '--tag', 'latest'], { from: 'user' }); + await program.parseAsync(['init', '--tag', 'latest'], { from: 'user' }); expect(fs.readFileSync('schema.zmodel', 'utf-8')).toBeTruthy(); }); @@ -111,7 +111,7 @@ describe('CLI Tests', () => { fs.writeFileSync('schema.zmodel', origZModelContent); createNpmrc(); const program = createProgram(); - program.parse(['init', '--tag', 'latest'], { from: 'user' }); + await program.parseAsync(['init', '--tag', 'latest'], { from: 'user' }); expect(fs.readFileSync('schema.zmodel', 'utf-8')).toEqual(origZModelContent); }); }); diff --git a/packages/schema/tests/cli/config.test.ts b/packages/schema/tests/cli/config.test.ts new file mode 100644 index 000000000..9259d1c42 --- /dev/null +++ b/packages/schema/tests/cli/config.test.ts @@ -0,0 +1,98 @@ +/* eslint-disable @typescript-eslint/no-var-requires */ +/// + +import * as fs from 'fs'; +import * as tmp from 'tmp'; +import { createProgram } from '../../src/cli'; +import { CliError } from '../../src/cli/cli-error'; +import { config } from '../../src/cli/config'; +import { GUARD_FIELD_NAME, TRANSACTION_FIELD_NAME } from '@zenstackhq/sdk'; + +describe('CLI Config Tests', () => { + let projDir: string; + let origDir: string; + + beforeEach(() => { + origDir = process.cwd(); + const r = tmp.dirSync(); + projDir = r.name; + console.log(`Project dir: ${projDir}`); + process.chdir(projDir); + }); + + afterEach(() => { + fs.rmSync(projDir, { recursive: true, force: true }); + process.chdir(origDir); + }); + + it('invalid default config', async () => { + fs.writeFileSync('package.json', JSON.stringify({ name: 'my app', version: '1.0.0' })); + fs.writeFileSync('zenstack.config.json', JSON.stringify({ abc: 'def' })); + + const program = createProgram(); + await expect(program.parseAsync(['init', '--tag', 'latest'], { from: 'user' })).rejects.toBeInstanceOf( + CliError + ); + }); + + it('valid default config empty', async () => { + fs.writeFileSync('package.json', JSON.stringify({ name: 'my app', version: '1.0.0' })); + fs.writeFileSync('zenstack.config.json', JSON.stringify({})); + + const program = createProgram(); + await program.parseAsync(['init', '--tag', 'latest'], { from: 'user' }); + + // custom config + expect(config.guardFieldName).toBe(GUARD_FIELD_NAME); + + // default value + expect(config.transactionFieldName).toBe(TRANSACTION_FIELD_NAME); + }); + + it('valid default config non-empty', async () => { + fs.writeFileSync('package.json', JSON.stringify({ name: 'my app', version: '1.0.0' })); + fs.writeFileSync( + 'zenstack.config.json', + JSON.stringify({ guardFieldName: 'myGuardField', transactionFieldName: 'myTransactionField' }) + ); + + const program = createProgram(); + await program.parseAsync(['init', '--tag', 'latest'], { from: 'user' }); + + // custom config + expect(config.guardFieldName).toBe('myGuardField'); + + // default value + expect(config.transactionFieldName).toBe('myTransactionField'); + }); + + it('config not found', async () => { + fs.writeFileSync('package.json', JSON.stringify({ name: 'my app', version: '1.0.0' })); + const program = createProgram(); + await expect( + program.parseAsync(['init', '--tag', 'latest', '--config', 'my.config.json'], { from: 'user' }) + ).rejects.toBeInstanceOf(CliError); + }); + + it('valid custom config file', async () => { + fs.writeFileSync('package.json', JSON.stringify({ name: 'my app', version: '1.0.0' })); + fs.writeFileSync('my.config.json', JSON.stringify({ guardFieldName: 'myGuardField' })); + const program = createProgram(); + await program.parseAsync(['init', '--tag', 'latest', '--config', 'my.config.json'], { from: 'user' }); + + // custom config + expect(config.guardFieldName).toBe('myGuardField'); + + // default value + expect(config.transactionFieldName).toBe(TRANSACTION_FIELD_NAME); + }); + + it('invalid custom config file', async () => { + fs.writeFileSync('package.json', JSON.stringify({ name: 'my app', version: '1.0.0' })); + fs.writeFileSync('my.config.json', JSON.stringify({ abc: 'def' })); + const program = createProgram(); + await expect( + program.parseAsync(['init', '--tag', 'latest', '--config', 'my.config.json'], { from: 'user' }) + ).rejects.toBeInstanceOf(CliError); + }); +}); diff --git a/packages/schema/tests/generator/prisma-generator.test.ts b/packages/schema/tests/generator/prisma-generator.test.ts index 263c335e6..d763376b1 100644 --- a/packages/schema/tests/generator/prisma-generator.test.ts +++ b/packages/schema/tests/generator/prisma-generator.test.ts @@ -191,4 +191,35 @@ describe('Prisma generator test', () => { expect(content).toContain('@@schema("base")'); expect(content).toContain('schemas = ["base","transactional"]'); }); + + it('custom aux field names', async () => { + const model = await loadModel(` + datasource db { + provider = 'postgresql' + url = env('URL') + } + + model Foo { + id String @id + value Int + @@allow('create', value > 0) + } + `); + + const { name } = tmp.fileSync({ postfix: '.prisma' }); + await new PrismaSchemaGenerator().generate( + model, + { + provider: '@core/prisma', + schemaPath: 'schema.zmodel', + output: name, + }, + { guardFieldName: 'myGuardField', transactionFieldName: 'myTransactionField' } + ); + + const content = fs.readFileSync(name, 'utf-8'); + await getDMMF({ datamodel: content }); + expect(content).toContain('@map("myGuardField")'); + expect(content).toContain('@map("myTransactionField")'); + }); }); diff --git a/packages/sdk/src/types.ts b/packages/sdk/src/types.ts index bf9d2c3c0..bb967d234 100644 --- a/packages/sdk/src/types.ts +++ b/packages/sdk/src/types.ts @@ -17,7 +17,8 @@ export type PluginOptions = { provider?: string; schemaPath: string } & Record ) => Promise | string[] | Promise | void; /** diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 503065162..6a548c055 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -415,6 +415,9 @@ importers: zod: specifier: ^3.19.1 version: 3.19.1 + zod-validation-error: + specifier: ^0.2.1 + version: 0.2.1(zod@3.19.1) devDependencies: '@types/async-exit-hook': specifier: ^2.0.0