+import os
 from typing import List
 import random
 import sys
@@ -39,10 +40,8 @@ def shard_weights(env, weights, weight_shardings):
   sharded = {}
   for key, val in weights.items():
     sharding = env.sharding_by_axis(weight_shardings.get(key, -1))
-    print("SHARDING", key, sharding)
     with jax.default_device(jax.devices("cpu")[0]):
       arr = torch_xla2.tensor.t2j(val)
-
     arr = jax.device_put(arr, sharding)
     sharded[key] = torchjax.to_torch(arr)
   return sharded
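
For context on the hunk above: each torch tensor is converted to a JAX array on the CPU, then `jax.device_put` lays it out across devices according to the per-weight sharding returned by `env.sharding_by_axis`. A minimal standalone sketch of that device_put pattern, assuming a `NamedSharding` over a one-axis mesh (the mesh and partition spec here are illustrative, not the engine's actual ones):

    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # Illustrative one-axis mesh over all local devices.
    mesh = Mesh(jax.devices(), axis_names=("x",))
    sharding = NamedSharding(mesh, PartitionSpec("x"))  # split axis 0 across "x"

    arr = jnp.zeros((8, 128))            # stand-in for a converted weight;
                                         # axis 0 must divide by the device count
    arr = jax.device_put(arr, sharding)  # each device now holds a slice of axis 0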
@@ -57,17 +56,16 @@ def create_engine(devices):
       FLAGS.override_batch_size,
       FLAGS.max_input_length,
       FLAGS.max_output_length,
-      quant_config.enable_weight_quantization,
   )
   tokenizer = AutoTokenizer.from_pretrained(FLAGS.model_id)
   env = environment.JetEngineEnvironment(env_data)
   env.hf_tokenizer = tokenizer
   model = fetch_models.instantiate_model_from_repo_id(FLAGS.model_id, env)
+  # NOTE: this is assigned later because the model should be constructed
+  # as a float model first, then quantized
+  env.quant_config = quant_config
   if quant_config.enable_weight_quantization:
     quantize_model.quantize_model(model, quant_config)
-  print("====== model =======")
-  print(model)
-
   weight_shardings = model.get_sharding_annotations()
   sharded_weights = shard_weights(env, model.state_dict(), weight_shardings)
   env_data.quant_config = quant_config
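
The reordering above makes the construction pipeline explicit: instantiate the model with float weights, attach `quant_config` to the environment, quantize only afterwards, and shard the resulting state dict. The same float-first, quantize-second shape shows up in stock PyTorch dynamic quantization; a small sketch of that analogous flow (illustrative only; `quantize_model.quantize_model` is jetstream_pt's own routine, not this API):

    import torch

    # 1. Construct the model as a regular float32 model first.
    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())

    # 2. Quantize afterwards; here, dynamic int8 quantization of Linear layers.
    qmodel = torch.ao.quantization.quantize_dynamic(
        model, {torch.nn.Linear}, dtype=torch.qint8
    )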
@@ -202,7 +200,7 @@ def interactive():
       "<s>[INST] <<SYS>>\n You are an AI assistant. You will be given a task. You must generate a detailed and long answer.\n <</SYS>>\n \n Continue the following story.\n \n Kay didn't have shoes that fit her feet properly. She only wore sneakers, because the \n Choose from: [I] shoes fitted badly. [II] sneakers fitted badly. [/INST]",
   ]
   for prompt in prompts:
-    slot = random.randint(0, FLAGS.batch_size - 1)
+    slot = random.randint(0, FLAGS.override_batch_size - 1)
     tokens, true_length = tokenizer.encode(prompt)

     print(f"---- Input prompts are: {prompt}")
@@ -330,10 +328,10 @@ def benchmark_offline():
   decode_time_ms = sum(dec_times[2:]) * 1000 / 8

   largest_prefill = max(prefill_times.items())
-  print("MAX tokens:", FLAGS.batch_size / avg_decode_times)
+  print("MAX tokens:", FLAGS.override_batch_size / avg_decode_times)

-  time2 = (FLAGS.batch_size * FLAGS.max_decode_length) / (
-      FLAGS.batch_size * largest_prefill[1]
+  time2 = (FLAGS.override_batch_size * FLAGS.max_decode_length) / (
+      FLAGS.override_batch_size * largest_prefill[1]
       + FLAGS.max_decode_length * avg_decode_times
   )
   print("MAX tokens 2:", time2)
@@ -351,6 +349,8 @@ def main():

 def main_real(argv):
   """Entry point"""
+  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
+  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
   if len(argv) < 2:
     print("Invalid arguments. please specify 'list' or 'serve'")