Commit 17884bc ("pyink")
1 parent: c86925e

5 files changed: +85 additions, -46 deletions

jetstream_pt/config.py

Lines changed: 7 additions & 3 deletions
@@ -42,7 +42,11 @@
 
 # Quantization related flags
 flags.DEFINE_bool("quantize_weights", False, "weight quantization")
-flags.DEFINE_bool("quantize_activation", False, "Quantize Q,K,V projection and FeedForward activation.")
+flags.DEFINE_bool(
+    "quantize_activation",
+    False,
+    "Quantize Q,K,V projection and FeedForward activation.",
+)
 flags.DEFINE_string(
     "quantize_type", "int8_per_channel", "Type of quantization."
 )
@@ -91,9 +95,9 @@ def create_quantization_config_from_flags():
   config.enable_weight_quantization = True
   config.num_bits_weight = 8 if "int8" in quantize_type else 4
   config.is_blockwise_weight = "blockwise" in quantize_type
-
+
   config.enable_activation_quantization = FLAGS.quantize_activation
-
+
   config.enable_kv_quantization = FLAGS.quantize_kv_cache
   return config
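For context, a minimal sketch of how the new flag feeds the quantization config, assuming the module is importable as jetstream_pt.config and that flags are parsed through absl as usual; everything outside this diff is illustrative:

# Sketch only: exercises the flag-to-config flow shown in the hunks above.
from absl import flags

from jetstream_pt import config as jetstream_config

FLAGS = flags.FLAGS
# Parse a fake argv; absl bool flags accept the bare --flag form.
FLAGS(["prog", "--quantize_weights", "--quantize_activation"])

quant_config = jetstream_config.create_quantization_config_from_flags()
# Per the second hunk, with the default quantize_type of "int8_per_channel":
#   quant_config.enable_weight_quantization     -> True
#   quant_config.num_bits_weight                -> 8
#   quant_config.enable_activation_quantization -> True  (new in this commit)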

jetstream_pt/layers.py

Lines changed: 37 additions & 17 deletions
@@ -102,10 +102,10 @@ def __init__(
 
     # Number of bits of weight tensor
     self.n_bit = quant_config.num_bits_weight
-
+
     # Quantize activation
     self.quantize_activation = quant_config.enable_activation_quantization
-
+
     # Flag to enable dequantize weight first, then do matmul. Useful for debugging.
     self.run_fake_quantize = False
 
@@ -136,8 +136,14 @@ def forward(self, inputs):
       if not self.quantize_activation:
         result = F.linear(inputs, self.weight)
       else:
-        result = torchjax.call_jax(jax.lax.dot_general, inputs, self.weight,
-                                   (((2,),(1)),((),())), None, torch.int32)
+        result = torchjax.call_jax(
+            jax.lax.dot_general,
+            inputs,
+            self.weight,
+            (((2,), (1)), ((), ())),
+            None,
+            torch.int32,
+        )
       result = result * self.weight_scaler
       if self.quantize_activation:
         result = result * act_s
@@ -182,15 +188,21 @@ def __init__(
     self.block_size = quant_config.block_size_weight
     n_blocks = in_features // self.block_size
 
-    assert not quant_config.enable_activation_quantization, "Activation quantization not supported for blockwise quantized matmul."
-
+    assert (
+        not quant_config.enable_activation_quantization
+    ), "Activation quantization not supported for blockwise quantized matmul."
+
     if self.use_dot_general:
       weight = torch.ones(
-          (n_blocks, out_features, self.block_size), dtype=torch.int8, device=device
+          (n_blocks, out_features, self.block_size),
+          dtype=torch.int8,
+          device=device,
       )
     else:
       weight = torch.ones(
-          (n_blocks, self.block_size, out_features), dtype=torch.int8, device=device
+          (n_blocks, self.block_size, out_features),
+          dtype=torch.int8,
+          device=device,
       )
     self.register_buffer("weight", weight)
@@ -209,7 +221,7 @@ def __init__(
     self.register_buffer("zero_point", None)
 
     self.n_bit = quant_config.num_bits_weight
-
+
     # Quantize activation
     self.quantize_activation = quant_config.enable_activation_quantization
 
@@ -240,15 +252,23 @@ def quantize_weight_from_nn_linear(self, weight):
   def forward(self, inputs):
     if not self.run_fake_quantize:
       if self.use_dot_general or self.flatten:
-        assert self.zero_point is None, "Blockwise quantized linear doesn't support zero_point in dot_general or einsum flattened implementation."
-      blockwise_matmul_kernel = blockwise_jax_kernel if not self.use_dot_general and not self.flatten else blockwise_jax_kernel_dot_general if self.use_dot_general else blockwise_jax_kernel_einsum_flatten
+        assert (
+            self.zero_point is None
+        ), "Blockwise quantized linear doesn't support zero_point in dot_general or einsum flattened implementation."
+      blockwise_matmul_kernel = (
+          blockwise_jax_kernel
+          if not self.use_dot_general and not self.flatten
+          else blockwise_jax_kernel_dot_general
+          if self.use_dot_general
+          else blockwise_jax_kernel_einsum_flatten
+      )
       result = torchjax.call_jax(
-        blockwise_matmul_kernel,
-        inputs,
-        self.weight,
-        self.weight_scaler,
-        self.zero_point,
-      )
+          blockwise_matmul_kernel,
+          inputs,
+          self.weight,
+          self.weight_scaler,
+          self.zero_point,
+      )
       return result
     else:
       # Fake quantization, debugging purpose.
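The else branch in the second hunk accumulates the int8 activation times int8 weight product in int32 through jax.lax.dot_general. A standalone JAX sketch of that contraction, with illustrative shapes not taken from the repo:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# (batch, seq, in_features) quantized activations and (out_features, in_features) weight.
x_int8 = jax.random.randint(key, (2, 16, 64), -128, 128, dtype=jnp.int8)
w_int8 = jax.random.randint(key, (32, 64), -128, 128, dtype=jnp.int8)

# Same dimension numbers as the diff: contract activation axis 2 with weight axis 1,
# no batch dimensions; accumulate in int32 instead of overflowing int8.
acc = jax.lax.dot_general(
    x_int8,
    w_int8,
    dimension_numbers=(((2,), (1,)), ((), ())),
    preferred_element_type=jnp.int32,
)
assert acc.shape == (2, 16, 32) and acc.dtype == jnp.int32
# The layer then rescales the int32 result by the per-channel weight scaler and,
# when activation quantization is on, by the per-token activation scale act_s.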

jetstream_pt/quantize.py

Lines changed: 3 additions & 4 deletions
@@ -14,8 +14,9 @@
 
 from typing import Tuple, Union
 
-import torch
+import jax
 import jax.numpy as jnp
+import torch
 
 EPS = 1e-5
 
@@ -116,9 +117,7 @@ def blockwise_jax_kernel(inputs, weight, weight_scaler, zero_point):
   return out
 
 
-def blockwise_jax_kernel_dot_general(
-    inputs, weight, weight_scaler, zero_point
-):
+def blockwise_jax_kernel_dot_general(inputs, weight, weight_scaler, zero_point):
   """Blockwise Matmul kernel impl in JAX using dot general"""
   inputs_shape = inputs.shape
   block_size = weight.shape[2]
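For intuition, a rough sketch of the blockwise dequantize-and-matmul these kernels implement. The layouts below are assumptions for illustration only (weight as (n_blocks, block_size, out_features) int8 with a per-block, per-output-channel scaler); zero-point handling is omitted and the repo's actual kernels may differ:

import jax.numpy as jnp


def blockwise_matmul_sketch(inputs, weight, weight_scaler):
  """Illustrative blockwise matmul: y[..., o] = sum over k of s[k, o] * (x_block[..., k, :] @ w[k, :, o]).

  Assumed shapes (not taken from the repo):
    inputs:        (..., in_features), floating point
    weight:        (n_blocks, block_size, out_features), int8
    weight_scaler: (n_blocks, out_features), floating point
  """
  n_blocks, block_size, _ = weight.shape
  blocked = inputs.reshape(*inputs.shape[:-1], n_blocks, block_size)
  # Contract each block separately, then apply that block's scale and sum over blocks.
  per_block = jnp.einsum("...kb,kbo->...ko", blocked, weight.astype(inputs.dtype))
  return jnp.einsum("...ko,ko->...o", per_block, weight_scaler)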

jetstream_pt/third_party/gemma/model.py

Lines changed: 15 additions & 3 deletions
@@ -240,13 +240,25 @@ def __init__(
     linear_kwargs = {"quant_config": env.quant_config}
 
     self.gate_proj = Linear(
-        hidden_size, intermediate_size, bias=False, device=device, **linear_kwargs,
+        hidden_size,
+        intermediate_size,
+        bias=False,
+        device=device,
+        **linear_kwargs,
     )
     self.up_proj = Linear(
-        hidden_size, intermediate_size, bias=False, device=device, **linear_kwargs,
+        hidden_size,
+        intermediate_size,
+        bias=False,
+        device=device,
+        **linear_kwargs,
     )
     self.down_proj = Linear(
-        intermediate_size, hidden_size, bias=False, device=device, **linear_kwargs,
+        intermediate_size,
+        hidden_size,
+        bias=False,
+        device=device,
+        **linear_kwargs,
     )
 
   def forward(self, x):

tests/test_quantization.py

Lines changed: 23 additions & 19 deletions
@@ -48,18 +48,18 @@ def _calc_cosine_dist(self, x, y):
     return (torch.dot(x, y) / (x.norm() * y.norm())).item()
 
   def _nn_linear_run_and_compare(
-    self,
-    nn_linear,
-    qlinear_layer,
-    arg,
-  ):
-    torch_result = nn_linear(arg)
-    qlinear_layer.quantize_weight_from_nn_linear(nn_linear.weight)
-    result = helpers.call_xla_model(
-      qlinear_layer, qlinear_layer.state_dict(), arg
-    )
-    diff = result - torch_result
-    return result, torch_result, diff
+      self,
+      nn_linear,
+      qlinear_layer,
+      arg,
+  ):
+    torch_result = nn_linear(arg)
+    qlinear_layer.quantize_weight_from_nn_linear(nn_linear.weight)
+    result = helpers.call_xla_model(
+        qlinear_layer, qlinear_layer.state_dict(), arg
+    )
+    diff = result - torch_result
+    return result, torch_result, diff
 
   def _print_diff(self, w, w_dq):
     print("Print diff:")
@@ -195,7 +195,9 @@ def test_weight_only_quant(self):
     block_q_linear = WeightOnlyBlockwiseQuantizedLinear(
         in_features, out_features
     )
-    res, torch_res, block_diff = self._nn_linear_run_and_compare(nn_linear, block_q_linear, arg)
+    res, torch_res, block_diff = self._nn_linear_run_and_compare(
+        nn_linear, block_q_linear, arg
+    )
     # self.assertTrue(torch.allclose(res, torch_res, atol=1.5))
     # Block quant is more accurate than per_channel quant.
     self.assertLess(block_diff.norm(), per_channel_diff.norm())
@@ -210,7 +212,9 @@ def test_weight_only_quant(self):
     )
     # self._print_diff(res, torch_res)
     self.assertTrue(torch.allclose(res, torch_res, atol=2))
-    quant_config = QuantizationConfig(is_symmetric_weight=False, is_blockwise_weight=True)
+    quant_config = QuantizationConfig(
+        is_symmetric_weight=False, is_blockwise_weight=True
+    )
     block_q_linear = WeightOnlyBlockwiseQuantizedLinear(
         in_features, out_features, quant_config=quant_config
     )
@@ -273,28 +277,28 @@ def shard_and_lower(f, layer, state_dict_jax, input, shardings):
     opt_hlo = shard_and_lower(f, layer, state_dict_jax, input, sharding)
     self.assertFalse("all-to-all" in opt_hlo)
     self.assertFalse("all-reduce-scatter" in opt_hlo)
-
+
   def test_activation_quant_per_channel(self):
 
     out_features = 8
     in_features = 4
     block_size = 128
-
+
     arg = torch.randn(2, 1, in_features).to(torch.bfloat16)
     nn_linear = torch.nn.Linear(
         in_features, out_features, bias=False, dtype=torch.bfloat16
     )
     quant_config = QuantizationConfig(
-      enable_weight_quantization=True,
-      enable_activation_quantization=True,
+        enable_weight_quantization=True,
+        enable_activation_quantization=True,
     )
     per_channel_q_linear = WeightOnlyPerChannelQuantizedLinear(
         in_features, out_features, quant_config=quant_config
    )
     res, torch_res, _ = self._nn_linear_run_and_compare(
         nn_linear, per_channel_q_linear, arg
     )
-    self.assertGreater(self._calc_cosine_dist(res, torch_res), 0.9999)
+    self.assertGreater(self._calc_cosine_dist(res, torch_res), 0.9999)
 
 
 if __name__ == "__main__":
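The new test accepts a cosine similarity above 0.9999 between the quantized layer and the bfloat16 reference. A small self-contained PyTorch sketch of that metric on an int8 round-trip, independent of the repo's layers:

import torch

torch.manual_seed(0)
w = torch.randn(8, 4)

# Per-output-channel symmetric int8 quantization (illustrative, not the repo's implementation).
scale = w.abs().amax(dim=1, keepdim=True) / 127.0
w_q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
w_dq = w_q.to(torch.float32) * scale

# Same cosine metric as _calc_cosine_dist in the test above.
x, y = w.flatten(), w_dq.flatten()
cosine = (torch.dot(x, y) / (x.norm() * y.norm())).item()
print(cosine)  # typically well above 0.999 for int8, in line with the asserted threshold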
