Skip to content

Commit 7c243dd

Browse files
authored
feat(runtime): inject enhanced client or tx context so it can be retrieved in extensions (#2018)
1 parent 1c5900e commit 7c243dd

File tree

2 files changed

+84
-1
lines changed

2 files changed

+84
-1
lines changed

packages/runtime/src/enhancements/node/proxy.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ export function makeProxy<T extends PrismaProxyHandler>(
284284
return propVal;
285285
}
286286

287-
return createHandlerProxy(makeHandler(target, prop), propVal, prop, errorTransformer);
287+
return createHandlerProxy(makeHandler(target, prop), propVal, prop, proxy, errorTransformer);
288288
},
289289
});
290290

@@ -298,10 +298,15 @@ function createHandlerProxy<T extends PrismaProxyHandler>(
298298
handler: T,
299299
origTarget: any,
300300
model: string,
301+
dbOrTx: any,
301302
errorTransformer?: ErrorTransformer
302303
): T {
303304
return new Proxy(handler, {
304305
get(target, propKey) {
306+
if (propKey === '$parent') {
307+
return dbOrTx;
308+
}
309+
305310
const prop = target[propKey as keyof T];
306311
if (typeof prop !== 'function') {
307312
// the proxy handler doesn't have this method, fall back to the original target
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import { loadSchema } from '@zenstackhq/testtools';
2+
3+
describe('Proxy Extension Context', () => {
4+
it('works', async () => {
5+
const { enhance } = await loadSchema(
6+
`
7+
model Counter {
8+
model String @unique
9+
value Int
10+
11+
@@allow('all', true)
12+
}
13+
14+
model Address {
15+
id String @id @default(cuid())
16+
city String
17+
18+
@@allow('all', true)
19+
}
20+
`
21+
);
22+
23+
const db = enhance();
24+
const dbExtended = db.$extends({
25+
client: {
26+
$one() {
27+
return 1;
28+
}
29+
},
30+
model: {
31+
$allModels: {
32+
async createWithCounter(this: any, args: any) {
33+
const modelName = this.$name;
34+
const dbOrTx = this.$parent;
35+
36+
// prisma exposes some internal properties, makes sure these are still preserved
37+
expect(dbOrTx._engine).toBeDefined();
38+
39+
const fn = async (tx: any) => {
40+
const counter = await tx.counter.findUnique({
41+
where: { model: modelName },
42+
});
43+
44+
await tx.counter.upsert({
45+
where: { model: modelName },
46+
update: { value: (counter?.value ?? 0) + tx.$one() },
47+
create: { model: modelName, value: tx.$one() },
48+
});
49+
50+
return tx[modelName].create(args);
51+
};
52+
53+
if (dbOrTx['$transaction']) {
54+
// not running in a transaction, so we need to create a new transaction
55+
return dbOrTx.$transaction(fn);
56+
}
57+
58+
return fn(dbOrTx);
59+
},
60+
},
61+
},
62+
});
63+
64+
const cities = ['Vienna', 'New York', 'Delhi'];
65+
66+
await Promise.all([
67+
...cities.map((city) => dbExtended.address.createWithCounter({ data: { city } })),
68+
...cities.map((city) =>
69+
dbExtended.$transaction((tx: any) => tx.address.createWithCounter({ data: { city: `${city}$tx` } }))
70+
),
71+
]);
72+
73+
await expect(dbExtended.counter.findUniqueOrThrow({ where: { model: 'Address' } })).resolves.toMatchObject({
74+
model: 'Address',
75+
value: cities.length * 2,
76+
});
77+
});
78+
});

0 commit comments

Comments
 (0)