 import torch_xla2
 from jax import lax
 from jetstream_pt import torchjax
+from jetstream_pt.environment import QuantizationConfig
 from jetstream_pt.quantize import (
     dequantize_tensor,
     load_q_weight_helper,
     quantize_tensor,
+    blockwise_jax_kernel,
+    blockwise_jax_kernel_dot_general,
+    blockwise_jax_kernel_einsum_flatten,
 )
 from torch import nn
 from . import attention_kernel as ak
@@ -68,8 +72,7 @@ def __init__(
       out_features,
       bias=False,
       device=None,
-      is_symmetric=True,
-      n_bit=8,
+      quant_config=QuantizationConfig(),
   ):
     super().__init__()
     self.in_features = in_features
@@ -85,8 +88,9 @@ def __init__(
     )
     self.register_buffer("weight_scaler", weight_scaler)

-    self.is_symmetric = is_symmetric
-    if not is_symmetric:
+    self.is_symmetric_weight = quant_config.is_symmetric_weight
+
+    if not self.is_symmetric_weight:
       zero_point = torch.ones(
           (out_features,), dtype=torch.bfloat16, device=device
       )
@@ -96,7 +100,12 @@ def __init__(

     assert not bias, "Quantized Linear doesn't support bias."

-    self.n_bit = n_bit
+    # Number of bits of weight tensor
+    self.n_bit = quant_config.num_bits_weight
+
+    # Quantize activation
+    self.quantize_activation = quant_config.enable_activation_quantization
+

     # Flag to enable dequantize weight first, then do matmul. Useful for debugging.
     self.run_fake_quantize = False

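Note: the constructor now reads everything it needs off a single `QuantizationConfig` instead of separate `is_symmetric`/`n_bit` arguments. The fields referenced in this change are `is_symmetric_weight`, `num_bits_weight`, `enable_activation_quantization` and, further down, `block_size_weight`. A hypothetical, minimal stand-in for that config (the real class lives in `jetstream_pt.environment` and may carry more fields and different defaults):

```python
import dataclasses


@dataclasses.dataclass
class QuantizationConfig:
  """Illustrative stand-in; only the fields this diff touches are listed."""

  num_bits_weight: int = 8                      # weight bit width
  is_symmetric_weight: bool = True              # symmetric vs. asymmetric weights
  enable_activation_quantization: bool = False  # quantize activations too
  block_size_weight: int = 128                  # assumed default, mirrors the old block_size=128
```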
@@ -115,23 +124,40 @@ def quantize_weight_from_nn_linear(self, weight):
         self.in_features,
     ), f"Got unexpected weight of shape {weight.shape}, expected weight shape ({self.out_features}, {self.in_features})."
     w_q, scale, zp = quantize_tensor(
-        weight, (1,), self.n_bit, self.is_symmetric, block_size=-1
+        weight, (1,), self.n_bit, self.is_symmetric_weight, block_size=-1
     )
     w_dq = dequantize_tensor(w_q, scale, zp)
     self._load_quantized_weights(w_q, scale, zp)

   def forward(self, inputs):
     if not self.run_fake_quantize:
-      if self.is_symmetric:
-        return torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
+      if self.quantize_activation:
+        inputs, act_s, _ = quantize_tensor(inputs, reduce_axis=(2,))
+      if not self.quantize_activation:
+        result = F.linear(inputs, self.weight)
       else:
-        out = torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
+        # We have to call jax because we need to do dot(int8, int8)->int32.
+        # This semantic cannot be represented in torch. The inferred output dtype
+        # will be int8 in torch, causing the dot result to overflow.
+        result = torchjax.call_jax(
+            jax.lax.dot_general,
+            inputs,
+            self.weight,
+            (((2,), (1)), ((), ())),
+            None,
+            jnp.int32.dtype,
+        )
+      result = result * self.weight_scaler
+      if self.quantize_activation:
+        result = result * act_s
+      if not self.is_symmetric_weight:
         zp_out = torch.einsum("...c,z->...z", inputs, self.zero_point)
-        return out - zp_out
+        result = result - zp_out
+      return result
     else:
       # Fake quantization, debugging purpose.
       scaler = self.weight_scaler.unsqueeze(-1)
-      if not self.is_symmetric:
+      if not self.is_symmetric_weight:
         zero_point = self.zero_point.unsqueeze(-1) / scaler
       else:
         zero_point = None
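For context on the `jax.lax.dot_general` branch above: torch infers an int8 output for an int8 × int8 matmul, so the accumulation has to happen through JAX with an int32 preferred element type. A minimal standalone sketch of that pattern (toy shapes and scales, not the layer's actual code path):

```python
import jax
import jax.numpy as jnp

# Toy shapes: batch=2, seq=8, in_features=16, out_features=32.
x_int8 = jnp.ones((2, 8, 16), dtype=jnp.int8)  # quantized activations
w_int8 = jnp.ones((32, 16), dtype=jnp.int8)    # quantized weight, (out, in)

# Contract the last dim of x with the last dim of w, accumulating in int32
# so the int8*int8 products cannot overflow.
acc = jax.lax.dot_general(
    x_int8,
    w_int8,
    dimension_numbers=(((2,), (1,)), ((), ())),
    preferred_element_type=jnp.int32,
)
print(acc.shape, acc.dtype)  # (2, 8, 32) int32

# Dequantize: fold the per-output-channel weight scaler back in, and the
# per-token activation scale if activations were quantized.
weight_scaler = jnp.full((32,), 0.01, dtype=jnp.bfloat16)
act_scale = jnp.full((2, 8, 1), 0.05, dtype=jnp.bfloat16)
out = acc * weight_scaler * act_scale
```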
@@ -149,32 +175,37 @@ def __init__(
       out_features,
       bias=False,
       device=None,
-      is_symmetric=True,
-      use_dot_general=False,
-      block_size=128,
-      n_bit=8,
+      quant_config=QuantizationConfig(),
   ):
     super().__init__()
     self.in_features = in_features
     self.out_features = out_features

     # Use dot general instead of einsum
     # Use dot general is slow now.
-    self.use_dot_general = use_dot_general
+    self.use_dot_general = False
     # Flatten einsum operands to 3D. XLA was slow if operands are 4D. But it's fixed now.
     # Same perf as non flattened one now.
     self.flatten = False

-    self.block_size = block_size
-    n_blocks = in_features // block_size
+    self.block_size = quant_config.block_size_weight
+    n_blocks = in_features // self.block_size
+
+    assert (
+        not quant_config.enable_activation_quantization
+    ), "Activation quantization not supported for blockwise quantized matmul."

     if self.use_dot_general:
       weight = torch.ones(
-          (n_blocks, out_features, block_size), dtype=torch.int8, device=device
+          (n_blocks, out_features, self.block_size),
+          dtype=torch.int8,
+          device=device,
       )
     else:
       weight = torch.ones(
-          (n_blocks, block_size, out_features), dtype=torch.int8, device=device
+          (n_blocks, self.block_size, out_features),
+          dtype=torch.int8,
+          device=device,
       )
     self.register_buffer("weight", weight)

@@ -183,16 +214,20 @@ def __init__(
     )
     self.register_buffer("weight_scaler", weight_scaler)

-    self.is_symmetric = is_symmetric
-    if not self.is_symmetric:
+    self.is_symmetric_weight = quant_config.is_symmetric_weight
+    if not self.is_symmetric_weight:
       zero_point = torch.ones(
           (n_blocks, out_features), dtype=torch.bfloat16, device=device
       )
       self.register_buffer("zero_point", zero_point)
     else:
       self.register_buffer("zero_point", None)

-    self.n_bit = n_bit
+    self.n_bit = quant_config.num_bits_weight
+
+    # Quantize activation
+    self.quantize_activation = quant_config.enable_activation_quantization
+

     # Flag to enable dequantize weight first, then do matmul. Useful for debugging.
     self.run_fake_quantize = False
@@ -211,112 +246,37 @@ def quantize_weight_from_nn_linear(self, weight):
         self.in_features,
     ), f"Unexpected weight shape ({self.out_features}, {self.in_features})."
     w_q, scale, zp = quantize_tensor(
-        weight, (1,), self.n_bit, self.is_symmetric, self.block_size
+        weight, (1,), self.n_bit, self.is_symmetric_weight, self.block_size
     )
     w_dq = dequantize_tensor(w_q, scale, zp)
-    print("check qweight cosine dist: ", _calc_cosine_dist(weight, w_dq))
-    # breakpoint()
     self._load_quantized_weights(w_q, scale, zp)

-  @staticmethod
-  def blockwise_jax_kernel(inputs, weight, weight_scaler, zero_point):
-    """Blockwise Matmul kernel impl in JAX using einsum"""
-    weight = weight.astype(jnp.int8)
-    block_size = weight.shape[1]
-    inputs_shape = inputs.shape
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    out = jnp.einsum("scz,bdsc->bdsz", weight, inputs)
-    out = jnp.einsum("bdsz,sz->bdz", out, weight_scaler)
-    if zero_point is not None:
-      zp_out = jnp.einsum("bdsc,sz->bdz", inputs, zero_point)
-      out = out - zp_out
-    return out
-
-  @staticmethod
-  def blockwise_jax_kernel_dot_general(
-      inputs, weight, weight_scaler, zero_point
-  ):
-    """Blockwise Matmul kernel impl in JAX using dot general"""
-    inputs_shape = inputs.shape
-    block_size = weight.shape[2]
-    bs = inputs_shape[0]
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    inputs = jax.lax.collapse(inputs, 0, 2)
-    out = jax.lax.dot_general(
-        inputs, weight, dimension_numbers=([(2), (2)], [(1), (0)])
-    )
-    out = jax.lax.dot_general(
-        out, weight_scaler, dimension_numbers=([(0), (0)], [(2), (1)])
-    )
-    out = jax.lax.transpose(out, [1, 0])
-    out = out.reshape((bs, -1) + out.shape[1:])
-    return out
-
-  @staticmethod
-  def blockwise_jax_kernel_einsum_flatten(
-      inputs, weight, weight_scaler, zero_point
-  ):
-    """Blockwise Matmul kernel impl in JAX using einsum, with operands flattened"""
-    weight = weight.astype(jnp.int8)
-    block_size = weight.shape[1]
-    inputs_shape = inputs.shape
-    bs = inputs_shape[0]
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    inputs = jax.lax.collapse(inputs, 0, 2)
-    out = jnp.einsum("scz,bsc->bsz", weight, inputs)
-    out = jnp.einsum("bsz,sz->bz", out, weight_scaler)
-    out = out.reshape((bs, -1) + out.shape[1:])
-    return out
-
   def forward(self, inputs):
     if not self.run_fake_quantize:
-      if self.use_dot_general:
+      if self.use_dot_general or self.flatten:
         assert (
             self.zero_point is None
-        ), "Blockwise quantized linear doesn't support zero_point in dot_general implementation."
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel_dot_general,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
-      if self.flatten:
-        assert (
-            self.zero_point is None
-        ), "Blockwise quantized linear doesn't support zero_point in einsum (flattened) implementation."
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel_einsum_flatten,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
-      else:
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
+        ), "Blockwise quantized linear doesn't support zero_point in dot_general or einsum flattened implementation."
+      blockwise_matmul_kernel = (
+          blockwise_jax_kernel
+          if not self.use_dot_general and not self.flatten
+          else blockwise_jax_kernel_dot_general
+          if self.use_dot_general
+          else blockwise_jax_kernel_einsum_flatten
+      )
+      result = torchjax.call_jax(
+          blockwise_matmul_kernel,
+          inputs,
+          self.weight,
+          self.weight_scaler,
+          self.zero_point,
+      )
+      return result
     else:
       # Fake quantization, debugging purpose.
       weight = self.weight.permute(2, 0, 1).to(torch.bfloat16)
       scaler = self.weight_scaler.unsqueeze(-1).transpose(1, 0)
-      if not self.is_symmetric:
+      if not self.is_symmetric_weight:
         zero_point = self.zero_point.unsqueeze(-1).transpose(1, 0) / scaler
       else:
         zero_point = None
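The kernels this forward now dispatches to (`blockwise_jax_kernel` and friends) are imported from `jetstream_pt.quantize`; their bodies presumably match the static methods removed above. A shape-level sketch of the default einsum path with toy sizes, to show how the block axis flows through the two contractions:

```python
import jax.numpy as jnp

b, d, n_blocks, block_size, out = 2, 4, 3, 8, 16  # toy sizes
inputs = jnp.ones((b, d, n_blocks * block_size), dtype=jnp.bfloat16)
weight = jnp.ones((n_blocks, block_size, out), dtype=jnp.int8)
weight_scaler = jnp.full((n_blocks, out), 0.02, dtype=jnp.bfloat16)

# Line the hidden dim up with its weight blocks: (b, d, n_blocks, block_size).
x = inputs.reshape(b, d, n_blocks, block_size)

# "scz,bdsc->bdsz": per-block matmul, keeping the block axis s.
# "bdsz,sz->bdz": apply the per-(block, out-channel) scaler and sum over s.
per_block = jnp.einsum("scz,bdsc->bdsz", weight.astype(jnp.bfloat16), x)
result = jnp.einsum("bdsz,sz->bdz", per_block, weight_scaler)
print(result.shape)  # (2, 4, 16)
```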
@@ -554,12 +514,16 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
     self.hidden_size = hidden_size

     LinearLayer = get_quantized_linear_layer(env.quant_config)
+    linear_kwargs = {}
+    if LinearLayer != torch.nn.Linear:
+      linear_kwargs = {"quant_config": env.quant_config}

     self.wo = LinearLayer(
         n_heads * self.head_dim,
         hidden_size,
         bias=False,
         device=device,
+        **linear_kwargs,
     )

     Kernel = (
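The `linear_kwargs` pattern above keeps every call site uniform: the extra `quant_config` argument is only forwarded when the selected class is a quantized layer, since `torch.nn.Linear` would reject the unknown keyword. The same idea in isolation (a hypothetical helper, exercised against plain `nn.Linear` so it runs standalone):

```python
from torch import nn


def make_linear(layer_cls, in_features, out_features, device, quant_config=None):
  """Build a linear layer, forwarding quant_config only to quantized classes."""
  kwargs = {}
  if layer_cls is not nn.Linear:
    # Assumption: the quantized layer classes in this repo accept quant_config.
    kwargs["quant_config"] = quant_config
  return layer_cls(in_features, out_features, bias=False, device=device, **kwargs)


# With the plain layer nothing extra is passed:
wq = make_linear(nn.Linear, 512, 512, device="cpu")
print(type(wq).__name__, wq.weight.shape)  # Linear torch.Size([512, 512])
```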
@@ -578,25 +542,29 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
           (n_heads + 2 * self.n_kv_heads) * self.head_dim,
           bias=False,
           device=device,
+          **linear_kwargs,
       )
     else:
       self.wq = LinearLayer(
           hidden_size,
           n_heads * self.head_dim,
           bias=False,
           device=device,
+          **linear_kwargs,
       )
       self.wk = LinearLayer(
           hidden_size,
           self.n_kv_heads * self.head_dim,
           bias=False,
           device=device,
+          **linear_kwargs,
       )
       self.wv = LinearLayer(
           hidden_size,
           self.n_kv_heads * self.head_dim,
           bias=False,
           device=device,
+          **linear_kwargs,
       )

   def load_hook(self, state_dict, prefix, *args):