Commit 8a125b6

Add activation quantization support to per-channel quantized linear layers (#105)
* add activation quant support
* pyink
* fix dtype
* uncomment prompts
* try fix test; add debug print to debug; remove print, add bias to asym quant tests; lint
* add comment
1 parent e2ee7dd commit 8a125b6

7 files changed: +253 -147 lines changed

jetstream_pt/config.py

Lines changed: 8 additions & 0 deletions

@@ -42,6 +42,11 @@
 
 # Quantization related flags
 flags.DEFINE_bool("quantize_weights", False, "weight quantization")
+flags.DEFINE_bool(
+    "quantize_activation",
+    False,
+    "Quantize Q,K,V projection and FeedForward activation.",
+)
 flags.DEFINE_string(
     "quantize_type", "int8_per_channel", "Type of quantization."
 )

@@ -90,6 +95,9 @@ def create_quantization_config_from_flags():
     config.enable_weight_quantization = True
     config.num_bits_weight = 8 if "int8" in quantize_type else 4
     config.is_blockwise_weight = "blockwise" in quantize_type
+
+  config.enable_activation_quantization = FLAGS.quantize_activation
+
   config.enable_kv_quantization = FLAGS.quantize_kv_cache
   return config
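
As a quick illustration of how the new flag is meant to be consumed, here is a minimal sketch; the flag-parsing call and argv values are only an example, while the flag and config field names come from this diff:

from absl import flags

from jetstream_pt import config as jpt_config

FLAGS = flags.FLAGS
# Parse example command-line values; the server parses the real argv instead.
FLAGS(["prog", "--quantize_weights=True", "--quantize_activation=True"])

quant_config = jpt_config.create_quantization_config_from_flags()
# The new flag now rides along on the quantization config object.
assert quant_config.enable_activation_quantization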

jetstream_pt/environment.py

Lines changed: 4 additions & 0 deletions

@@ -32,6 +32,10 @@ class QuantizationConfig:
   enable_weight_quantization: bool = False
   num_bits_weight: int = 8
   is_blockwise_weight: bool = False
+  block_size_weight: int = 128
+  is_symmetric_weight: bool = True
+
+  enable_activation_quantization: bool = False
 
   enable_kv_quantization: bool = False
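
For tests or entry points that do not go through absl flags, and assuming QuantizationConfig is a dataclass (as the annotated defaults above suggest), it can be filled in directly. A minimal sketch; the particular values are only an example:

from jetstream_pt.environment import QuantizationConfig

# Per-channel int8 weights with quantized activations. Blockwise stays off here
# because this change only wires activation quantization into the per-channel path.
quant_config = QuantizationConfig(
    enable_weight_quantization=True,
    num_bits_weight=8,
    is_blockwise_weight=False,
    is_symmetric_weight=True,
    enable_activation_quantization=True,
)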

jetstream_pt/layers.py

Lines changed: 85 additions & 117 deletions

@@ -25,10 +25,14 @@
 import torch_xla2
 from jax import lax
 from jetstream_pt import torchjax
+from jetstream_pt.environment import QuantizationConfig
 from jetstream_pt.quantize import (
     dequantize_tensor,
     load_q_weight_helper,
     quantize_tensor,
+    blockwise_jax_kernel,
+    blockwise_jax_kernel_dot_general,
+    blockwise_jax_kernel_einsum_flatten,
 )
 from torch import nn
 from . import attention_kernel as ak

@@ -68,8 +72,7 @@ def __init__(
       out_features,
       bias=False,
       device=None,
-      is_symmetric=True,
-      n_bit=8,
+      quant_config=QuantizationConfig(),
   ):
     super().__init__()
     self.in_features = in_features

@@ -85,8 +88,9 @@ def __init__(
     )
     self.register_buffer("weight_scaler", weight_scaler)
 
-    self.is_symmetric = is_symmetric
-    if not is_symmetric:
+    self.is_symmetric_weight = quant_config.is_symmetric_weight
+
+    if not self.is_symmetric_weight:
       zero_point = torch.ones(
           (out_features,), dtype=torch.bfloat16, device=device
       )

@@ -96,7 +100,12 @@ def __init__(
 
     assert not bias, "Quantized Linear doesn't support bias."
 
-    self.n_bit = n_bit
+    # Number of bits of weight tensor
+    self.n_bit = quant_config.num_bits_weight
+
+    # Quantize activation
+    self.quantize_activation = quant_config.enable_activation_quantization
+
     # Flag to enable dequantize weight first, then do matmul. Useful for debugging.
     self.run_fake_quantize = False

@@ -115,23 +124,40 @@ def quantize_weight_from_nn_linear(self, weight):
         self.in_features,
     ), f"Got unexpected weight of shape {weight.shape}, expected weight shape ({self.out_features}, {self.in_features})."
     w_q, scale, zp = quantize_tensor(
-        weight, (1,), self.n_bit, self.is_symmetric, block_size=-1
+        weight, (1,), self.n_bit, self.is_symmetric_weight, block_size=-1
     )
     w_dq = dequantize_tensor(w_q, scale, zp)
     self._load_quantized_weights(w_q, scale, zp)
 
   def forward(self, inputs):
     if not self.run_fake_quantize:
-      if self.is_symmetric:
-        return torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
+      if self.quantize_activation:
+        inputs, act_s, _ = quantize_tensor(inputs, reduce_axis=(2,))
+      if not self.quantize_activation:
+        result = F.linear(inputs, self.weight)
       else:
-        out = torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
+        # We have to call jax because we need to do dot(int8, int8)->int32.
+        # This semantic cannot be represented in torch. The inferred output dtype
+        # will be int8 in torch, causing the dot result to overflow.
+        result = torchjax.call_jax(
+            jax.lax.dot_general,
+            inputs,
+            self.weight,
+            (((2,), (1)), ((), ())),
+            None,
+            jnp.int32.dtype,
+        )
+      result = result * self.weight_scaler
+      if self.quantize_activation:
+        result = result * act_s
+      if not self.is_symmetric_weight:
         zp_out = torch.einsum("...c,z->...z", inputs, self.zero_point)
-        return out - zp_out
+        result = result - zp_out
+      return result
     else:
       # Fake quantization, debugging purpose.
       scaler = self.weight_scaler.unsqueeze(-1)
-      if not self.is_symmetric:
+      if not self.is_symmetric_weight:
         zero_point = self.zero_point.unsqueeze(-1) / scaler
       else:
         zero_point = None
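
The comment in forward() above is the heart of this hunk: a torch int8 x int8 matmul infers an int8 output dtype and overflows, whereas jax.lax.dot_general lets the caller request an int32 accumulator (the diff passes jnp.int32.dtype positionally as preferred_element_type). A standalone sketch of that contract, separate from the layer code and with made-up shapes:

import jax
import jax.numpy as jnp

x = jnp.full((1, 4, 128), 127, dtype=jnp.int8)  # quantized activations (batch, seq, in)
w = jnp.full((256, 128), 127, dtype=jnp.int8)   # quantized weights (out, in)

# Contract x's last axis against w's last axis, no batch dims, and accumulate in
# int32 so per-row sums (up to 127 * 127 * 128) do not wrap around.
out = jax.lax.dot_general(
    x, w, (((2,), (1,)), ((), ())), preferred_element_type=jnp.int32
)
print(out.dtype, out.shape)  # int32 (1, 4, 256)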
@@ -149,32 +175,37 @@ def __init__(
       out_features,
       bias=False,
       device=None,
-      is_symmetric=True,
-      use_dot_general=False,
-      block_size=128,
-      n_bit=8,
+      quant_config=QuantizationConfig(),
   ):
     super().__init__()
     self.in_features = in_features
     self.out_features = out_features
 
     # Use dot general instead of einsum
     # Use dot general is slow now.
-    self.use_dot_general = use_dot_general
+    self.use_dot_general = False
     # Flatten einsum operands to 3D. XLA was slow if operands are 4D. But it's fixed now.
     # Same perf as non flattened one now.
     self.flatten = False
 
-    self.block_size = block_size
-    n_blocks = in_features // block_size
+    self.block_size = quant_config.block_size_weight
+    n_blocks = in_features // self.block_size
+
+    assert (
+        not quant_config.enable_activation_quantization
+    ), "Activation quantization not supported for blockwise quantized matmul."
 
     if self.use_dot_general:
       weight = torch.ones(
-          (n_blocks, out_features, block_size), dtype=torch.int8, device=device
+          (n_blocks, out_features, self.block_size),
+          dtype=torch.int8,
+          device=device,
       )
     else:
       weight = torch.ones(
-          (n_blocks, block_size, out_features), dtype=torch.int8, device=device
+          (n_blocks, self.block_size, out_features),
+          dtype=torch.int8,
+          device=device,
       )
     self.register_buffer("weight", weight)

@@ -183,16 +214,20 @@ def __init__(
     )
     self.register_buffer("weight_scaler", weight_scaler)
 
-    self.is_symmetric = is_symmetric
-    if not self.is_symmetric:
+    self.is_symmetric_weight = quant_config.is_symmetric_weight
+    if not self.is_symmetric_weight:
       zero_point = torch.ones(
           (n_blocks, out_features), dtype=torch.bfloat16, device=device
       )
       self.register_buffer("zero_point", zero_point)
     else:
       self.register_buffer("zero_point", None)
 
-    self.n_bit = n_bit
+    self.n_bit = quant_config.num_bits_weight
+
+    # Quantize activation
+    self.quantize_activation = quant_config.enable_activation_quantization
+
     # Flag to enable dequantize weight first, then do matmul. Useful for debugging.
     self.run_fake_quantize = False

@@ -211,112 +246,37 @@ def quantize_weight_from_nn_linear(self, weight):
         self.in_features,
     ), f"Unexpected weight shape ({self.out_features}, {self.in_features})."
     w_q, scale, zp = quantize_tensor(
-        weight, (1,), self.n_bit, self.is_symmetric, self.block_size
+        weight, (1,), self.n_bit, self.is_symmetric_weight, self.block_size
     )
     w_dq = dequantize_tensor(w_q, scale, zp)
-    print("check qweight cosine dist: ", _calc_cosine_dist(weight, w_dq))
-    # breakpoint()
     self._load_quantized_weights(w_q, scale, zp)
 
-  @staticmethod
-  def blockwise_jax_kernel(inputs, weight, weight_scaler, zero_point):
-    """Blockwise Matmul kernel impl in JAX using einsum"""
-    weight = weight.astype(jnp.int8)
-    block_size = weight.shape[1]
-    inputs_shape = inputs.shape
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    out = jnp.einsum("scz,bdsc->bdsz", weight, inputs)
-    out = jnp.einsum("bdsz,sz->bdz", out, weight_scaler)
-    if zero_point is not None:
-      zp_out = jnp.einsum("bdsc,sz->bdz", inputs, zero_point)
-      out = out - zp_out
-    return out
-
-  @staticmethod
-  def blockwise_jax_kernel_dot_general(
-      inputs, weight, weight_scaler, zero_point
-  ):
-    """Blockwise Matmul kernel impl in JAX using dot general"""
-    inputs_shape = inputs.shape
-    block_size = weight.shape[2]
-    bs = inputs_shape[0]
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    inputs = jax.lax.collapse(inputs, 0, 2)
-    out = jax.lax.dot_general(
-        inputs, weight, dimension_numbers=([(2), (2)], [(1), (0)])
-    )
-    out = jax.lax.dot_general(
-        out, weight_scaler, dimension_numbers=([(0), (0)], [(2), (1)])
-    )
-    out = jax.lax.transpose(out, [1, 0])
-    out = out.reshape((bs, -1) + out.shape[1:])
-    return out
-
-  @staticmethod
-  def blockwise_jax_kernel_einsum_flatten(
-      inputs, weight, weight_scaler, zero_point
-  ):
-    """Blockwise Matmul kernel impl in JAX using einsum, with operands flattened"""
-    weight = weight.astype(jnp.int8)
-    block_size = weight.shape[1]
-    inputs_shape = inputs.shape
-    bs = inputs_shape[0]
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    inputs = jax.lax.collapse(inputs, 0, 2)
-    out = jnp.einsum("scz,bsc->bsz", weight, inputs)
-    out = jnp.einsum("bsz,sz->bz", out, weight_scaler)
-    out = out.reshape((bs, -1) + out.shape[1:])
-    return out
-
   def forward(self, inputs):
     if not self.run_fake_quantize:
-      if self.use_dot_general:
+      if self.use_dot_general or self.flatten:
         assert (
             self.zero_point is None
-        ), "Blockwise quantized linear doesn't support zero_point in dot_general implementation."
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel_dot_general,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
-      if self.flatten:
-        assert (
-            self.zero_point is None
-        ), "Blockwise quantized linear doesn't support zero_point in einsum (flattened) implementation."
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel_einsum_flatten,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
-      else:
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
+        ), "Blockwise quantized linear doesn't support zero_point in dot_general or einsum flattened implementation."
+      blockwise_matmul_kernel = (
+          blockwise_jax_kernel
+          if not self.use_dot_general and not self.flatten
+          else blockwise_jax_kernel_dot_general
+          if self.use_dot_general
+          else blockwise_jax_kernel_einsum_flatten
+      )
+      result = torchjax.call_jax(
+          blockwise_matmul_kernel,
+          inputs,
+          self.weight,
+          self.weight_scaler,
+          self.zero_point,
+      )
+      return result
     else:
       # Fake quantization, debugging purpose.
       weight = self.weight.permute(2, 0, 1).to(torch.bfloat16)
       scaler = self.weight_scaler.unsqueeze(-1).transpose(1, 0)
-      if not self.is_symmetric:
+      if not self.is_symmetric_weight:
         zero_point = self.zero_point.unsqueeze(-1).transpose(1, 0) / scaler
       else:
         zero_point = None
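
The three blockwise kernels referenced here now live in jetstream_pt.quantize (see the new imports at the top of the file) instead of being static methods. For readers skimming the diff, the einsum variant amounts to folding the channel axis into (block, sub-channel) and contracting per block before applying per-block scales. A toy sketch of just that shape logic, with made-up sizes and a hypothetical function name, not the library code:

import jax.numpy as jnp

def toy_blockwise_matmul(inputs, weight, weight_scaler):
  # inputs: (batch, seq, in_features); weight: (n_blocks, block_size, out_features) int8;
  # weight_scaler: (n_blocks, out_features), one scale per block per output channel.
  n_blocks, block_size, _ = weight.shape
  blocked = inputs.reshape(inputs.shape[:-1] + (n_blocks, block_size))
  # Contract each block separately, keeping the block axis ...
  partial = jnp.einsum("scz,bdsc->bdsz", weight, blocked)
  # ... then fold in the per-block scales while summing over blocks.
  return jnp.einsum("bdsz,sz->bdz", partial, weight_scaler)

x = jnp.ones((2, 3, 8), dtype=jnp.float32)
w = jnp.ones((4, 2, 5), dtype=jnp.int8)     # 4 blocks of size 2 -> in_features = 8
s = jnp.ones((4, 5), dtype=jnp.float32)
print(toy_blockwise_matmul(x, w, s).shape)  # (2, 3, 5)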
@@ -554,12 +514,16 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
     self.hidden_size = hidden_size
 
     LinearLayer = get_quantized_linear_layer(env.quant_config)
+    linear_kwargs = {}
+    if LinearLayer != torch.nn.Linear:
+      linear_kwargs = {"quant_config": env.quant_config}
 
     self.wo = LinearLayer(
         n_heads * self.head_dim,
         hidden_size,
         bias=False,
         device=device,
+        **linear_kwargs,
     )
 
     Kernel = (

@@ -578,25 +542,29 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
          (n_heads + 2 * self.n_kv_heads) * self.head_dim,
          bias=False,
          device=device,
+          **linear_kwargs,
      )
    else:
      self.wq = LinearLayer(
          hidden_size,
          n_heads * self.head_dim,
          bias=False,
          device=device,
+          **linear_kwargs,
      )
      self.wk = LinearLayer(
          hidden_size,
          self.n_kv_heads * self.head_dim,
          bias=False,
          device=device,
+          **linear_kwargs,
      )
      self.wv = LinearLayer(
          hidden_size,
          self.n_kv_heads * self.head_dim,
          bias=False,
          device=device,
+          **linear_kwargs,
      )
 
  def load_hook(self, state_dict, prefix, *args):
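
The linear_kwargs indirection exists because torch.nn.Linear does not accept a quant_config argument, while both quantized linear classes now take one. A condensed sketch of the pattern on its own; the layer sizes are placeholders, and it assumes get_quantized_linear_layer returns torch.nn.Linear when weight quantization is off, as the comparison above implies:

import torch
from jetstream_pt.environment import QuantizationConfig
from jetstream_pt.layers import get_quantized_linear_layer

quant_config = QuantizationConfig(enable_weight_quantization=True)

LinearLayer = get_quantized_linear_layer(quant_config)
linear_kwargs = {}
if LinearLayer != torch.nn.Linear:
  # Only the quantized layer classes understand quant_config.
  linear_kwargs = {"quant_config": quant_config}

wq = LinearLayer(4096, 4096, bias=False, device=None, **linear_kwargs)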
