Skip to content

Commit d20cbd4

Browse files
committed
feat(runtime): inject enhanced client or tx context so it can be retrieved in extensions
1 parent 82b8d25 commit d20cbd4

File tree

2 files changed

+89
-1
lines changed

2 files changed

+89
-1
lines changed

packages/runtime/src/enhancements/node/proxy.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ export function makeProxy<T extends PrismaProxyHandler>(
289289
return propVal;
290290
}
291291

292-
return createHandlerProxy(makeHandler(target, prop), propVal, prop, errorTransformer);
292+
return createHandlerProxy(makeHandler(target, prop), propVal, prop, target, errorTransformer);
293293
},
294294
});
295295

@@ -303,10 +303,15 @@ function createHandlerProxy<T extends PrismaProxyHandler>(
303303
handler: T,
304304
origTarget: any,
305305
model: string,
306+
dbOrTx: any,
306307
errorTransformer?: ErrorTransformer
307308
): T {
308309
return new Proxy(handler, {
309310
get(target, propKey) {
311+
if (propKey === '$zenstack_parent') {
312+
return dbOrTx;
313+
}
314+
310315
const prop = target[propKey as keyof T];
311316
if (typeof prop !== 'function') {
312317
// the proxy handler doesn't have this method, fall back to the original target
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import { loadSchema } from '@zenstackhq/testtools';
2+
3+
describe('Proxy Extension Context', () => {
4+
it('works', async () => {
5+
const { enhance } = await loadSchema(
6+
`
7+
model Counter {
8+
model String @unique
9+
value Int
10+
11+
@@allow('all', true)
12+
}
13+
14+
model Address {
15+
id String @id @default(cuid())
16+
city String
17+
18+
@@allow('all', true)
19+
}
20+
`
21+
);
22+
23+
const db = enhance();
24+
const dbExtended = db.$extends({
25+
model: {
26+
$allModels: {
27+
async createWithCounter(this: any, args: any) {
28+
const modelName = this.$name;
29+
const dbOrTx = this.$zenstack_parent;
30+
31+
const fn = async (tx: any) => {
32+
const counter = await tx.counter.findUnique({
33+
where: { model: modelName },
34+
});
35+
36+
await tx.counter.upsert({
37+
where: { model: modelName },
38+
update: { value: (counter?.value ?? 0) + 1 },
39+
create: { model: modelName, value: 1 },
40+
});
41+
42+
return tx[modelName].create(args);
43+
};
44+
45+
if (dbOrTx['$transaction']) {
46+
// not running in a transaction, so we need to create a new transaction
47+
return dbOrTx.$transaction(fn);
48+
}
49+
50+
return fn(dbOrTx);
51+
},
52+
},
53+
},
54+
});
55+
56+
const cities = [
57+
'Vienna',
58+
'Paris',
59+
'London',
60+
'Berlin',
61+
'New York',
62+
'Tokyo',
63+
'Sydney',
64+
'Seoul',
65+
'Mumbai',
66+
'Delhi',
67+
'Shanghai',
68+
];
69+
70+
await Promise.all([
71+
...cities.map((city) => dbExtended.address.createWithCounter({ data: { city } })),
72+
...cities.map((city) =>
73+
dbExtended.$transaction((tx: any) => tx.address.createWithCounter({ data: { city: `${city}$tx` } }))
74+
),
75+
]);
76+
77+
// expecting object
78+
await expect(dbExtended.counter.findUniqueOrThrow({ where: { model: 'Address' } })).resolves.toMatchObject({
79+
model: 'Address',
80+
value: cities.length * 2,
81+
});
82+
});
83+
});

0 commit comments

Comments
 (0)