
Commit c86925e

add activation quant support
1 parent bf73f02 commit c86925e

File tree

8 files changed: +210 −146 lines


jetstream_pt/config.py

Lines changed: 4 additions & 0 deletions
@@ -42,6 +42,7 @@
 
 # Quantization related flags
 flags.DEFINE_bool("quantize_weights", False, "weight quantization")
+flags.DEFINE_bool("quantize_activation", False, "Quantize Q,K,V projection and FeedForward activation.")
 flags.DEFINE_string(
     "quantize_type", "int8_per_channel", "Type of quantization."
 )
@@ -90,6 +91,9 @@ def create_quantization_config_from_flags():
   config.enable_weight_quantization = True
   config.num_bits_weight = 8 if "int8" in quantize_type else 4
   config.is_blockwise_weight = "blockwise" in quantize_type
+
+  config.enable_activation_quantization = FLAGS.quantize_activation
+
   config.enable_kv_quantization = FLAGS.quantize_kv_cache
   return config
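
For orientation, here is a minimal sketch of the config this mapping produces when both weight and activation quantization are turned on. The field names come from `QuantizationConfig` in the environment.py hunk below; constructing it directly like this (and assuming it is a plain dataclass) is an illustration, not part of the commit:

```python
from jetstream_pt.environment import QuantizationConfig

# Roughly what create_quantization_config_from_flags() yields for
# --quantize_weights --quantize_activation with the default
# quantize_type="int8_per_channel".
quant_config = QuantizationConfig(
    enable_weight_quantization=True,
    num_bits_weight=8,                    # "int8" in quantize_type
    is_blockwise_weight=False,            # no "blockwise" in quantize_type
    enable_activation_quantization=True,  # the new flag
)
```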

jetstream_pt/environment.py

Lines changed: 4 additions & 0 deletions
@@ -32,6 +32,10 @@ class QuantizationConfig:
   enable_weight_quantization: bool = False
   num_bits_weight: int = 8
   is_blockwise_weight: bool = False
+  block_size_weight: int = 128
+  is_symmetric_weight: bool = True
+
+  enable_activation_quantization: bool = False
 
   enable_kv_quantization: bool = False

jetstream_pt/layers.py

Lines changed: 64 additions & 118 deletions
@@ -25,10 +25,14 @@
 import torch_xla2
 from jax import lax
 from jetstream_pt import torchjax
+from jetstream_pt.environment import QuantizationConfig
 from jetstream_pt.quantize import (
     dequantize_tensor,
     load_q_weight_helper,
     quantize_tensor,
+    blockwise_jax_kernel,
+    blockwise_jax_kernel_dot_general,
+    blockwise_jax_kernel_einsum_flatten,
 )
 from torch import nn
 from . import attention_kernel as ak
@@ -68,8 +72,7 @@ def __init__(
       out_features,
       bias=False,
       device=None,
-      is_symmetric=True,
-      n_bit=8,
+      quant_config=QuantizationConfig(),
   ):
     super().__init__()
     self.in_features = in_features
@@ -85,8 +88,9 @@ def __init__(
     )
     self.register_buffer("weight_scaler", weight_scaler)
 
-    self.is_symmetric = is_symmetric
-    if not is_symmetric:
+    self.is_symmetric_weight = quant_config.is_symmetric_weight
+
+    if not self.is_symmetric_weight:
       zero_point = torch.ones(
           (out_features,), dtype=torch.bfloat16, device=device
       )
@@ -96,7 +100,12 @@ def __init__(
 
     assert not bias, "Quantized Linear doesn't support bias."
 
-    self.n_bit = n_bit
+    # Number of bits of weight tensor
+    self.n_bit = quant_config.num_bits_weight
+
+    # Quantize activation
+    self.quantize_activation = quant_config.enable_activation_quantization
+
     # Flag to enable dequantize weight first, then do matmul. Useful for debugging.
     self.run_fake_quantize = False
 
@@ -115,23 +124,31 @@ def quantize_weight_from_nn_linear(self, weight):
         self.in_features,
     ), f"Got unexpected weight of shape {weight.shape}, expected weight shape ({self.out_features}, {self.in_features})."
     w_q, scale, zp = quantize_tensor(
-        weight, (1,), self.n_bit, self.is_symmetric, block_size=-1
+        weight, (1,), self.n_bit, self.is_symmetric_weight, block_size=-1
     )
     w_dq = dequantize_tensor(w_q, scale, zp)
     self._load_quantized_weights(w_q, scale, zp)
 
   def forward(self, inputs):
     if not self.run_fake_quantize:
-      if self.is_symmetric:
-        return torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
+      if self.quantize_activation:
+        inputs, act_s, _ = quantize_tensor(inputs, reduce_axis=(2,))
+      if not self.quantize_activation:
+        result = F.linear(inputs, self.weight)
       else:
-        out = torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
+        result = torchjax.call_jax(jax.lax.dot_general, inputs, self.weight,
+                                   (((2,), (1)), ((), ())), None, torch.int32)
+      result = result * self.weight_scaler
+      if self.quantize_activation:
+        result = result * act_s
+      if not self.is_symmetric_weight:
         zp_out = torch.einsum("...c,z->...z", inputs, self.zero_point)
-        return out - zp_out
+        result = result - zp_out
+      return result
     else:
       # Fake quantization, debugging purpose.
       scaler = self.weight_scaler.unsqueeze(-1)
-      if not self.is_symmetric:
+      if not self.is_symmetric_weight:
        zero_point = self.zero_point.unsqueeze(-1) / scaler
       else:
        zero_point = None
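
The new path above quantizes the activation on the fly with one scale per token (reducing over the feature axis, `reduce_axis=(2,)`), keeps the weight quantized per output channel, multiplies the two int8 operands, and then rescales by `weight_scaler` and `act_s` to recover a floating-point result. A self-contained sketch of that arithmetic, independent of jetstream_pt — the `quantize_sym_int8` helper is made up for illustration, and the accumulation is done in float here for simplicity whereas the kernel above accumulates in int32 via `jax.lax.dot_general`:

```python
import torch

def quantize_sym_int8(x, dim):
  """Symmetric int8 quantization with one scale per slice along `dim`."""
  scale = x.abs().amax(dim=dim, keepdim=True) / 127.0
  return torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8), scale

x = torch.randn(2, 16, 64)   # (batch, seq, in_features)
w = torch.randn(128, 64)     # (out_features, in_features)

w_q, w_scale = quantize_sym_int8(w, dim=1)    # per output channel (weight_scaler)
x_q, x_scale = quantize_sym_int8(x, dim=-1)   # per token (act_s)

# int8 x int8 matmul, then rescale by both factors to get back to float.
acc = x_q.float() @ w_q.float().T
approx = acc * w_scale.squeeze(-1) * x_scale
print((approx - x @ w.T).abs().max())  # small quantization error
```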
@@ -149,32 +166,31 @@ def __init__(
       out_features,
       bias=False,
       device=None,
-      is_symmetric=True,
-      use_dot_general=False,
-      block_size=128,
-      n_bit=8,
+      quant_config=QuantizationConfig(),
   ):
     super().__init__()
     self.in_features = in_features
     self.out_features = out_features
 
     # Use dot general instead of einsum
     # Use dot general is slow now.
-    self.use_dot_general = use_dot_general
+    self.use_dot_general = False
     # Flatten einsum operands to 3D. XLA was slow if operands are 4D. But it's fixed now.
     # Same perf as non flattened one now.
     self.flatten = False
 
-    self.block_size = block_size
-    n_blocks = in_features // block_size
+    self.block_size = quant_config.block_size_weight
+    n_blocks = in_features // self.block_size
 
+    assert not quant_config.enable_activation_quantization, "Activation quantization not supported for blockwise quantized matmul."
+
     if self.use_dot_general:
       weight = torch.ones(
-          (n_blocks, out_features, block_size), dtype=torch.int8, device=device
+          (n_blocks, out_features, self.block_size), dtype=torch.int8, device=device
       )
     else:
       weight = torch.ones(
-          (n_blocks, block_size, out_features), dtype=torch.int8, device=device
+          (n_blocks, self.block_size, out_features), dtype=torch.int8, device=device
       )
     self.register_buffer("weight", weight)
 
@@ -183,16 +199,20 @@ def __init__(
     )
     self.register_buffer("weight_scaler", weight_scaler)
 
-    self.is_symmetric = is_symmetric
-    if not self.is_symmetric:
+    self.is_symmetric_weight = quant_config.is_symmetric_weight
+    if not self.is_symmetric_weight:
       zero_point = torch.ones(
           (n_blocks, out_features), dtype=torch.bfloat16, device=device
       )
       self.register_buffer("zero_point", zero_point)
     else:
       self.register_buffer("zero_point", None)
 
-    self.n_bit = n_bit
+    self.n_bit = quant_config.num_bits_weight
+
+    # Quantize activation
+    self.quantize_activation = quant_config.enable_activation_quantization
+
     # Flag to enable dequantize weight first, then do matmul. Useful for debugging.
     self.run_fake_quantize = False
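
For context on the new `block_size_weight` field used above: blockwise quantization splits the contraction dimension into `in_features // block_size` blocks and keeps one scale per (block, output channel) pair rather than one scale per output channel. A rough sketch of that layout (illustration only; the layer itself stores the buffer as `(n_blocks, block_size, out_features)` with a matching per-block scaler):

```python
import torch

in_features, out_features, block_size = 512, 256, 128
n_blocks = in_features // block_size  # 4

w = torch.randn(out_features, in_features)
# One scale per (output channel, block) instead of one per output channel.
w_blocks = w.reshape(out_features, n_blocks, block_size)
scale = w_blocks.abs().amax(dim=-1, keepdim=True) / 127.0
w_q = torch.clamp(torch.round(w_blocks / scale), -127, 127).to(torch.int8)
print(w_q.shape, scale.shape)  # torch.Size([256, 4, 128]) torch.Size([256, 4, 1])
```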

@@ -211,112 +231,30 @@ def quantize_weight_from_nn_linear(self, weight):
         self.in_features,
     ), f"Unexpected weight shape ({self.out_features}, {self.in_features})."
     w_q, scale, zp = quantize_tensor(
-        weight, (1,), self.n_bit, self.is_symmetric, self.block_size
+        weight, (1,), self.n_bit, self.is_symmetric_weight, self.block_size
     )
     w_dq = dequantize_tensor(w_q, scale, zp)
     print("check qweight cosine dist: ", _calc_cosine_dist(weight, w_dq))
-    # breakpoint()
     self._load_quantized_weights(w_q, scale, zp)
 
-  @staticmethod
-  def blockwise_jax_kernel(inputs, weight, weight_scaler, zero_point):
-    """Blockwise Matmul kernel impl in JAX using einsum"""
-    weight = weight.astype(jnp.int8)
-    block_size = weight.shape[1]
-    inputs_shape = inputs.shape
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    out = jnp.einsum("scz,bdsc->bdsz", weight, inputs)
-    out = jnp.einsum("bdsz,sz->bdz", out, weight_scaler)
-    if zero_point is not None:
-      zp_out = jnp.einsum("bdsc,sz->bdz", inputs, zero_point)
-      out = out - zp_out
-    return out
-
-  @staticmethod
-  def blockwise_jax_kernel_dot_general(
-      inputs, weight, weight_scaler, zero_point
-  ):
-    """Blockwise Matmul kernel impl in JAX using dot general"""
-    inputs_shape = inputs.shape
-    block_size = weight.shape[2]
-    bs = inputs_shape[0]
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    inputs = jax.lax.collapse(inputs, 0, 2)
-    out = jax.lax.dot_general(
-        inputs, weight, dimension_numbers=([(2), (2)], [(1), (0)])
-    )
-    out = jax.lax.dot_general(
-        out, weight_scaler, dimension_numbers=([(0), (0)], [(2), (1)])
-    )
-    out = jax.lax.transpose(out, [1, 0])
-    out = out.reshape((bs, -1) + out.shape[1:])
-    return out
-
-  @staticmethod
-  def blockwise_jax_kernel_einsum_flatten(
-      inputs, weight, weight_scaler, zero_point
-  ):
-    """Blockwise Matmul kernel impl in JAX using einsum, with operands flattened"""
-    weight = weight.astype(jnp.int8)
-    block_size = weight.shape[1]
-    inputs_shape = inputs.shape
-    bs = inputs_shape[0]
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    inputs = jax.lax.collapse(inputs, 0, 2)
-    out = jnp.einsum("scz,bsc->bsz", weight, inputs)
-    out = jnp.einsum("bsz,sz->bz", out, weight_scaler)
-    out = out.reshape((bs, -1) + out.shape[1:])
-    return out
-
   def forward(self, inputs):
     if not self.run_fake_quantize:
-      if self.use_dot_general:
-        assert (
-            self.zero_point is None
-        ), "Blockwise quantized linear doesn't support zero_point in dot_general implementation."
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel_dot_general,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
-      if self.flatten:
-        assert (
-            self.zero_point is None
-        ), "Blockwise quantized linear doesn't support zero_point in einsum (flattened) implementation."
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel_einsum_flatten,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
-      else:
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
+      if self.use_dot_general or self.flatten:
+        assert self.zero_point is None, "Blockwise quantized linear doesn't support zero_point in dot_general or einsum flattened implementation."
+      blockwise_matmul_kernel = blockwise_jax_kernel if not self.use_dot_general and not self.flatten else blockwise_jax_kernel_dot_general if self.use_dot_general else blockwise_jax_kernel_einsum_flatten
+      result = torchjax.call_jax(
+          blockwise_matmul_kernel,
+          inputs,
+          self.weight,
+          self.weight_scaler,
+          self.zero_point,
+      )
+      return result
     else:
       # Fake quantization, debugging purpose.
       weight = self.weight.permute(2, 0, 1).to(torch.bfloat16)
       scaler = self.weight_scaler.unsqueeze(-1).transpose(1, 0)
-      if not self.is_symmetric:
+      if not self.is_symmetric_weight:
        zero_point = self.zero_point.unsqueeze(-1).transpose(1, 0) / scaler
       else:
        zero_point = None
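
The kernel selection added above is packed into a single conditional expression; unrolled, it reads like this (a restatement for readability, not additional code from the commit):

```python
if self.use_dot_general:
  blockwise_matmul_kernel = blockwise_jax_kernel_dot_general
elif self.flatten:
  blockwise_matmul_kernel = blockwise_jax_kernel_einsum_flatten
else:
  blockwise_matmul_kernel = blockwise_jax_kernel  # default einsum kernel
```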
@@ -554,12 +492,16 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
     self.hidden_size = hidden_size
 
     LinearLayer = get_quantized_linear_layer(env.quant_config)
+    linear_kwargs = {}
+    if LinearLayer != torch.nn.Linear:
+      linear_kwargs = {"quant_config": env.quant_config}
 
     self.wo = LinearLayer(
         n_heads * self.head_dim,
         hidden_size,
         bias=False,
         device=device,
+        **linear_kwargs,
     )
 
     Kernel = (
@@ -578,25 +520,29 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
           (n_heads + 2 * self.n_kv_heads) * self.head_dim,
           bias=False,
           device=device,
+          **linear_kwargs,
       )
     else:
       self.wq = LinearLayer(
           hidden_size,
           n_heads * self.head_dim,
           bias=False,
           device=device,
+          **linear_kwargs,
       )
       self.wk = LinearLayer(
           hidden_size,
           self.n_kv_heads * self.head_dim,
           bias=False,
           device=device,
+          **linear_kwargs,
      )
       self.wv = LinearLayer(
           hidden_size,
           self.n_kv_heads * self.head_dim,
           bias=False,
           device=device,
+          **linear_kwargs,
      )
 
   def load_hook(self, state_dict, prefix, *args):
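
The `**linear_kwargs` pattern used throughout this hunk keeps the call sites identical whether quantization is on or off: `quant_config` is forwarded only when `get_quantized_linear_layer` returns a quantized layer class, since `torch.nn.Linear` does not accept that argument. A condensed sketch of the same idea, reusing names from the diff (not additional commit code):

```python
LinearLayer = get_quantized_linear_layer(env.quant_config)
linear_kwargs = (
    {"quant_config": env.quant_config} if LinearLayer != torch.nn.Linear else {}
)
# Every projection is then built the same way, quantized or not:
wq = LinearLayer(hidden_size, n_heads * head_dim, bias=False, device=device, **linear_kwargs)
```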
