Skip to content

Commit 1ab82df

Browse files
committed
fix dtype
1 parent 17884bc commit 1ab82df

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

jetstream_pt/layers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def forward(self, inputs):
142142
self.weight,
143143
(((2,), (1)), ((), ())),
144144
None,
145-
torch.int32,
145+
jnp.int32.dtype,
146146
)
147147
result = result * self.weight_scaler
148148
if self.quantize_activation:

0 commit comments

Comments
 (0)