27
27
flags .DEFINE_integer ("max_output_length" , 1024 , "The batch size" )
28
28
flags .DEFINE_integer ("port" , 9000 , "port to listen on" )
29
29
flags .DEFINE_integer ("threads" , 64 , "number of worker threads in thread pool" )
30
- flags .DEFINE_string ("benchmark_save_offline_result_to_file" , "" , "if set, then save the result to the given file name" )
30
+ flags .DEFINE_string (
31
+ "benchmark_save_offline_result_to_file" ,
32
+ "" ,
33
+ "if set, then save the result to the given file name" ,
34
+ )
31
35
32
36
33
37
def shard_weights (env , weights , weight_shardings ):
@@ -115,21 +119,22 @@ def _check_model_id():
115
119
list_model ()
116
120
sys .exit (1 )
117
121
118
- def _run_prefill_time (engine , params , decode_state , seqlen , profiler_started ):
122
+
123
+ def _run_prefill_time (pt_engine , params , decode_state , seqlen , profiler_started ):
119
124
"""Run prefill and measure time."""
120
- metadata = engine .get_tokenizer ()
121
- tokenizer = engine .build_tokenizer (metadata )
125
+ metadata = pt_engine .get_tokenizer ()
126
+ tokenizer = pt_engine .build_tokenizer (metadata )
122
127
123
128
text = "This is a beautiful day"
124
129
tokens , true_length = tokenizer .encode (
125
130
text , is_bos = True , prefill_lengths = [seqlen ]
126
131
)
127
132
128
133
for _ in range (3 ):
129
- prefill_result , _ = engine .prefill (
134
+ prefill_result , _ = pt_engine .prefill (
130
135
params = params , padded_tokens = tokens , true_length = true_length
131
136
)
132
- decode_state = engine .insert (
137
+ decode_state = pt_engine .insert (
133
138
prefill_result , decode_state , slot = jnp .int32 (1 )
134
139
)
135
140
@@ -140,10 +145,10 @@ def _run_prefill_time(engine, params, decode_state, seqlen, profiler_started):
140
145
jax .profiler .start_trace (FLAGS .profiling_output )
141
146
profiler_started = True
142
147
143
- prefill_result , _ = engine .prefill (
148
+ prefill_result , _ = pt_engine .prefill (
144
149
params = params , padded_tokens = tokens , true_length = true_length
145
150
)
146
- decode_state = engine .insert (
151
+ decode_state = pt_engine .insert (
147
152
prefill_result , decode_state , slot = jnp .int32 (i )
148
153
)
149
154
jax .block_until_ready (decode_state )
@@ -244,25 +249,28 @@ def interactive():
244
249
print ("---- All output text." )
245
250
print (tokenizer .decode (sampled_tokens_list ))
246
251
252
+
247
253
def _save_benchmark_to_file (filename , prefill_times_ms , decode_time_ms ):
248
- lines = [
249
- " # Offline benchmark numbers" ,
250
- " ## Model: " + FLAGS .model_id ,
251
- " ## Batch size: {}" .format (FLAGS .override_batch_size ),
252
- " ## Quantize: {}" .format (FLAGS .quantize_weights ),
253
- " | | time (ms) |" ,
254
- " |-------|-----------|" ,
255
- ] + [
256
- "| Prefill {} | {} |" .format (x , y ) for x , y in prefill_times_ms .items ()
257
- ] + [
258
- "| Decode | {} |" .format (decode_time_ms )
259
- ]
260
- with open (filename , 'w' ) as f :
261
- f .write ('\n ' .join (lines ))
254
+ lines = (
255
+ [
256
+ " # Offline benchmark numbers" ,
257
+ " ## Model: " + FLAGS .model_id ,
258
+ f" ## Batch size: { FLAGS .override_batch_size } " ,
259
+ f" ## Quantize: { FLAGS .quantize_weights } " ,
260
+ " | | time (ms) |" ,
261
+ " |-------|-----------|" ,
262
+ ]
263
+ + [
264
+ f"| Prefill { x } | { y } |"
265
+ for x , y in prefill_times_ms .items ()
266
+ ]
267
+ + [f"| Decode | { decode_time_ms } |" ]
268
+ )
269
+ with open (filename , "w" , encoding = 'utf-8' ) as f :
270
+ f .write ("\n " .join (lines ))
262
271
f .flush ()
263
272
264
273
265
-
266
274
def benchmark_offline ():
267
275
"""function to run engine offline."""
268
276
_check_model_id ()
@@ -280,7 +288,7 @@ def benchmark_offline():
280
288
profiler_started = False
281
289
# 16 .. 1024
282
290
for exp in range (4 , 11 ):
283
- batch = 2 ** exp
291
+ batch = 2 ** exp
284
292
runtime , decode_state , profiler_started = _run_prefill_time (
285
293
pt_engine , params , decode_state , batch , profiler_started
286
294
)
@@ -333,13 +341,12 @@ def benchmark_offline():
333
341
334
342
if FLAGS .benchmark_save_offline_result_to_file :
335
343
_save_benchmark_to_file (
336
- FLAGS .benchmark_save_offline_result_to_file ,
337
- prefill_times_ms ,
338
- decode_time_ms
344
+ FLAGS .benchmark_save_offline_result_to_file ,
345
+ prefill_times_ms ,
346
+ decode_time_ms ,
339
347
)
340
348
341
349
342
-
343
350
def main ():
344
351
"""Main function."""
345
352
0 commit comments