import torch_xla2
from jax import lax
from jetstream_pt import torchjax
+from jetstream_pt.environment import QuantizationConfig
from jetstream_pt.quantize import (
    dequantize_tensor,
    load_q_weight_helper,
    quantize_tensor,
+    blockwise_jax_kernel,
+    blockwise_jax_kernel_dot_general,
+    blockwise_jax_kernel_einsum_flatten,
)
from torch import nn
from . import attention_kernel as ak
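
The constructors below now read four attributes from QuantizationConfig instead of taking individual keyword arguments: num_bits_weight, is_symmetric_weight, enable_activation_quantization, and block_size_weight. For orientation only, a minimal stand-in with just those fields might look like the sketch below (an assumption, not the actual jetstream_pt.environment definition; the defaults are guesses):

import dataclasses

# Hypothetical stand-in for jetstream_pt.environment.QuantizationConfig,
# restricted to the fields this diff reads. Field defaults are assumptions.
@dataclasses.dataclass
class QuantizationConfigSketch:
  num_bits_weight: int = 8                      # bit width of the quantized weights
  is_symmetric_weight: bool = True              # symmetric vs. asymmetric weight quantization
  enable_activation_quantization: bool = False  # also quantize activations (int8 matmul path)
  block_size_weight: int = 128                  # block size for blockwise weight quantization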
@@ -68,8 +72,7 @@ def __init__(
      out_features,
      bias=False,
      device=None,
-      is_symmetric=True,
-      n_bit=8,
+      quant_config=QuantizationConfig(),
  ):
    super().__init__()
    self.in_features = in_features
@@ -85,8 +88,9 @@ def __init__(
    )
    self.register_buffer("weight_scaler", weight_scaler)

-    self.is_symmetric = is_symmetric
-    if not is_symmetric:
+    self.is_symmetric_weight = quant_config.is_symmetric_weight
+
+    if not self.is_symmetric_weight:
      zero_point = torch.ones(
          (out_features,), dtype=torch.bfloat16, device=device
      )
@@ -96,7 +100,12 @@ def __init__(

    assert not bias, "Quantized Linear doesn't support bias."

-    self.n_bit = n_bit
+    # Number of bits of weight tensor
+    self.n_bit = quant_config.num_bits_weight
+
+    # Quantize activation
+    self.quantize_activation = quant_config.enable_activation_quantization
+
    # Flag to enable dequantize weight first, then do matmul. Useful for debugging.
    self.run_fake_quantize = False

@@ -115,23 +124,31 @@ def quantize_weight_from_nn_linear(self, weight):
        self.in_features,
    ), f"Got unexpected weight of shape {weight.shape}, expected weight shape ({self.out_features}, {self.in_features})."
    w_q, scale, zp = quantize_tensor(
-        weight, (1,), self.n_bit, self.is_symmetric, block_size=-1
+        weight, (1,), self.n_bit, self.is_symmetric_weight, block_size=-1
    )
    w_dq = dequantize_tensor(w_q, scale, zp)
    self._load_quantized_weights(w_q, scale, zp)

  def forward(self, inputs):
    if not self.run_fake_quantize:
-      if self.is_symmetric:
-        return torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
+      if self.quantize_activation:
+        inputs, act_s, _ = quantize_tensor(inputs, reduce_axis=(2,))
+      if not self.quantize_activation:
+        result = F.linear(inputs, self.weight)
      else:
-        out = torch.mul(F.linear(inputs, self.weight), self.weight_scaler)
+        result = torchjax.call_jax(jax.lax.dot_general, inputs, self.weight,
+                                   (((2,), (1,)), ((), ())), None, torch.int32)
+      result = result * self.weight_scaler
+      if self.quantize_activation:
+        result = result * act_s
+      if not self.is_symmetric_weight:
        zp_out = torch.einsum("...c,z->...z", inputs, self.zero_point)
-        return out - zp_out
+        result = result - zp_out
+      return result
    else:
      # Fake quantization, debugging purpose.
      scaler = self.weight_scaler.unsqueeze(-1)
-      if not self.is_symmetric:
+      if not self.is_symmetric_weight:
        zero_point = self.zero_point.unsqueeze(-1) / scaler
      else:
        zero_point = None
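
The per-channel forward above now quantizes activations per token along the feature axis (reduce_axis=(2,)), runs an int8 x int8 dot_general accumulated in int32, and then rescales by the per-output-channel weight scale and the per-token activation scale. A self-contained plain-JAX sketch of that computation, assuming symmetric weights (the helper name and the weight quantization here are illustrative, not the jetstream_pt implementation):

import jax
import jax.numpy as jnp

def int8_per_channel_matmul(x, w_q, w_scale):
  # Per-token activation quantization along the feature axis.
  act_scale = jnp.max(jnp.abs(x), axis=2, keepdims=True) / 127.0
  x_q = jnp.clip(jnp.round(x / act_scale), -128, 127).astype(jnp.int8)
  # int8 x int8 contraction of x's axis 2 with w_q's axis 1, accumulated in int32.
  acc = jax.lax.dot_general(
      x_q, w_q, (((2,), (1,)), ((), ())), preferred_element_type=jnp.int32
  )
  # Undo both scales: per-output-channel weight scale and per-token activation scale.
  return acc.astype(jnp.float32) * w_scale * act_scale

x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 16))   # (batch, seq, in)
w = jax.random.normal(jax.random.PRNGKey(1), (8, 16))      # (out, in)
w_scale = jnp.max(jnp.abs(w), axis=1) / 127.0              # (out,)
w_q = jnp.clip(jnp.round(w / w_scale[:, None]), -128, 127).astype(jnp.int8)
# Result stays close to the float matmul, up to quantization error.
print(jnp.max(jnp.abs(int8_per_channel_matmul(x, w_q, w_scale) - x @ w.T)))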
@@ -149,32 +166,31 @@ def __init__(
      out_features,
      bias=False,
      device=None,
-      is_symmetric=True,
-      use_dot_general=False,
-      block_size=128,
-      n_bit=8,
+      quant_config=QuantizationConfig(),
  ):
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features

    # Use dot general instead of einsum
    # Use dot general is slow now.
-    self.use_dot_general = use_dot_general
+    self.use_dot_general = False
    # Flatten einsum operands to 3D. XLA was slow if operands are 4D. But it's fixed now.
    # Same perf as non flattened one now.
    self.flatten = False

-    self.block_size = block_size
-    n_blocks = in_features // block_size
+    self.block_size = quant_config.block_size_weight
+    n_blocks = in_features // self.block_size

+    assert not quant_config.enable_activation_quantization, "Activation quantization not supported for blockwise quantized matmul."
+
    if self.use_dot_general:
      weight = torch.ones(
-          (n_blocks, out_features, block_size), dtype=torch.int8, device=device
+          (n_blocks, out_features, self.block_size), dtype=torch.int8, device=device
      )
    else:
      weight = torch.ones(
-          (n_blocks, block_size, out_features), dtype=torch.int8, device=device
+          (n_blocks, self.block_size, out_features), dtype=torch.int8, device=device
      )
    self.register_buffer("weight", weight)

@@ -183,16 +199,20 @@ def __init__(
    )
    self.register_buffer("weight_scaler", weight_scaler)

-    self.is_symmetric = is_symmetric
-    if not self.is_symmetric:
+    self.is_symmetric_weight = quant_config.is_symmetric_weight
+    if not self.is_symmetric_weight:
      zero_point = torch.ones(
          (n_blocks, out_features), dtype=torch.bfloat16, device=device
      )
      self.register_buffer("zero_point", zero_point)
    else:
      self.register_buffer("zero_point", None)

-    self.n_bit = n_bit
+    self.n_bit = quant_config.num_bits_weight
+
+    # Quantize activation
+    self.quantize_activation = quant_config.enable_activation_quantization
+
    # Flag to enable dequantize weight first, then do matmul. Useful for debugging.
    self.run_fake_quantize = False

@@ -211,112 +231,30 @@ def quantize_weight_from_nn_linear(self, weight):
        self.in_features,
    ), f"Unexpected weight shape ({self.out_features}, {self.in_features})."
    w_q, scale, zp = quantize_tensor(
-        weight, (1,), self.n_bit, self.is_symmetric, self.block_size
+        weight, (1,), self.n_bit, self.is_symmetric_weight, self.block_size
    )
    w_dq = dequantize_tensor(w_q, scale, zp)
    print("check qweight cosine dist: ", _calc_cosine_dist(weight, w_dq))
-    # breakpoint()
    self._load_quantized_weights(w_q, scale, zp)

-  @staticmethod
-  def blockwise_jax_kernel(inputs, weight, weight_scaler, zero_point):
-    """Blockwise Matmul kernel impl in JAX using einsum"""
-    weight = weight.astype(jnp.int8)
-    block_size = weight.shape[1]
-    inputs_shape = inputs.shape
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    out = jnp.einsum("scz,bdsc->bdsz", weight, inputs)
-    out = jnp.einsum("bdsz,sz->bdz", out, weight_scaler)
-    if zero_point is not None:
-      zp_out = jnp.einsum("bdsc,sz->bdz", inputs, zero_point)
-      out = out - zp_out
-    return out
-
-  @staticmethod
-  def blockwise_jax_kernel_dot_general(
-      inputs, weight, weight_scaler, zero_point
-  ):
-    """Blockwise Matmul kernel impl in JAX using dot general"""
-    inputs_shape = inputs.shape
-    block_size = weight.shape[2]
-    bs = inputs_shape[0]
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    inputs = jax.lax.collapse(inputs, 0, 2)
-    out = jax.lax.dot_general(
-        inputs, weight, dimension_numbers=([(2), (2)], [(1), (0)])
-    )
-    out = jax.lax.dot_general(
-        out, weight_scaler, dimension_numbers=([(0), (0)], [(2), (1)])
-    )
-    out = jax.lax.transpose(out, [1, 0])
-    out = out.reshape((bs, -1) + out.shape[1:])
-    return out
-
-  @staticmethod
-  def blockwise_jax_kernel_einsum_flatten(
-      inputs, weight, weight_scaler, zero_point
-  ):
-    """Blockwise Matmul kernel impl in JAX using einsum, with operands flattened"""
-    weight = weight.astype(jnp.int8)
-    block_size = weight.shape[1]
-    inputs_shape = inputs.shape
-    bs = inputs_shape[0]
-    inputs_new_shape = inputs_shape[:-1] + (
-        inputs_shape[-1] // block_size,
-        block_size,
-    )
-    inputs = inputs.reshape(inputs_new_shape)
-    inputs = jax.lax.collapse(inputs, 0, 2)
-    out = jnp.einsum("scz,bsc->bsz", weight, inputs)
-    out = jnp.einsum("bsz,sz->bz", out, weight_scaler)
-    out = out.reshape((bs, -1) + out.shape[1:])
-    return out
-
  def forward(self, inputs):
    if not self.run_fake_quantize:
-      if self.use_dot_general:
-        assert (
-            self.zero_point is None
-        ), "Blockwise quantized linear doesn't support zero_point in dot_general implementation."
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel_dot_general,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
-      if self.flatten:
-        assert (
-            self.zero_point is None
-        ), "Blockwise quantized linear doesn't support zero_point in einsum (flattened) implementation."
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel_einsum_flatten,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
-      else:
-        return torchjax.call_jax(
-            WeightOnlyBlockwiseQuantizedLinear.blockwise_jax_kernel,
-            inputs,
-            self.weight,
-            self.weight_scaler,
-            self.zero_point,
-        )
+      if self.use_dot_general or self.flatten:
+        assert self.zero_point is None, "Blockwise quantized linear doesn't support zero_point in dot_general or einsum flattened implementation."
+      blockwise_matmul_kernel = (
+          blockwise_jax_kernel
+          if not self.use_dot_general and not self.flatten
+          else blockwise_jax_kernel_dot_general
+          if self.use_dot_general
+          else blockwise_jax_kernel_einsum_flatten
+      )
+      result = torchjax.call_jax(
+          blockwise_matmul_kernel,
+          inputs,
+          self.weight,
+          self.weight_scaler,
+          self.zero_point,
+      )
+      return result
    else:
      # Fake quantization, debugging purpose.
      weight = self.weight.permute(2, 0, 1).to(torch.bfloat16)
      scaler = self.weight_scaler.unsqueeze(-1).transpose(1, 0)
-      if not self.is_symmetric:
+      if not self.is_symmetric_weight:
        zero_point = self.zero_point.unsqueeze(-1).transpose(1, 0) / scaler
      else:
        zero_point = None
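
The three kernels deleted above are not gone; they now live in jetstream_pt.quantize and are imported at the top of the file, and the forward above dispatches between them. As a quick self-contained illustration of the contraction the default einsum kernel performs ("scz,bdsc->bdsz" followed by "bdsz,sz->bdz"), assuming symmetric blockwise-quantized weights and illustrative shapes (plain JAX, not the jetstream_pt code):

import jax
import jax.numpy as jnp

in_features, out_features, block_size = 16, 8, 4
n_blocks = in_features // block_size

# Quantize a float weight per (block, output channel), matching the stored layout:
# weight is (n_blocks, block_size, out_features), weight_scaler is (n_blocks, out_features).
w = jax.random.normal(jax.random.PRNGKey(0), (out_features, in_features))
w_blocks = w.reshape(out_features, n_blocks, block_size)   # (z, s, c)
scale = jnp.max(jnp.abs(w_blocks), axis=2) / 127.0         # (z, s)
w_q = jnp.round(w_blocks / scale[:, :, None]).astype(jnp.int8)
weight = jnp.transpose(w_q, (1, 2, 0))                     # (s, c, z)
weight_scaler = scale.T                                    # (s, z)

# Same two einsums as blockwise_jax_kernel: per-block partial matmuls,
# then scale each block's contribution and sum over blocks.
x = jax.random.normal(jax.random.PRNGKey(1), (2, 3, in_features))
x_blocks = x.reshape(2, 3, n_blocks, block_size)           # (b, d, s, c)
out = jnp.einsum("scz,bdsc->bdsz", weight, x_blocks)
out = jnp.einsum("bdsz,sz->bdz", out, weight_scaler)
print(jnp.max(jnp.abs(out - x @ w.T)))  # close to the float matmul, up to quantization error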
@@ -554,12 +492,16 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
    self.hidden_size = hidden_size

    LinearLayer = get_quantized_linear_layer(env.quant_config)
+    linear_kwargs = {}
+    if LinearLayer != torch.nn.Linear:
+      linear_kwargs = {"quant_config": env.quant_config}

    self.wo = LinearLayer(
        n_heads * self.head_dim,
        hidden_size,
        bias=False,
        device=device,
+        **linear_kwargs,
    )

    Kernel = (
@@ -578,25 +520,29 @@ def __init__(self, n_heads, n_kv_heads, head_dim, hidden_size, device, env):
          (n_heads + 2 * self.n_kv_heads) * self.head_dim,
          bias=False,
          device=device,
+          **linear_kwargs,
      )
    else:
      self.wq = LinearLayer(
          hidden_size,
          n_heads * self.head_dim,
          bias=False,
          device=device,
+          **linear_kwargs,
      )
      self.wk = LinearLayer(
          hidden_size,
          self.n_kv_heads * self.head_dim,
          bias=False,
          device=device,
+          **linear_kwargs,
      )
      self.wv = LinearLayer(
          hidden_size,
          self.n_kv_heads * self.head_dim,
          bias=False,
          device=device,
+          **linear_kwargs,
      )

  def load_hook(self, state_dict, prefix, *args):
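
The linear_kwargs indirection above exists because torch.nn.Linear does not accept a quant_config argument, while the quantized layers defined in this file now do. A stripped-down sketch of the same construction pattern (build_linear is a hypothetical helper, not part of jetstream_pt):

import torch

def build_linear(LinearLayer, in_features, out_features, device, quant_config=None):
  # Forward quant_config only to layer classes that accept it; the stock
  # torch.nn.Linear keeps its normal signature.
  linear_kwargs = {}
  if LinearLayer != torch.nn.Linear:
    linear_kwargs = {"quant_config": quant_config}
  return LinearLayer(in_features, out_features, bias=False, device=device, **linear_kwargs)

# The stock layer builds with no quantization arguments; a quantized class such as
# WeightOnlyPerChannelQuantizedLinear would be handed quant_config the same way.
layer = build_linear(torch.nn.Linear, 16, 8, device="cpu")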