Skip to content

Commit 8675c30

Browse files
sixiang-google authored and wang2yn84 committed
Make prefilling return first token for loadgen integration (#143)
* Make prefilling return first token for loadgen integration
* minor fix and lint
* enable passing of max_decode_length as a flag
1 parent 17ab200 commit 8675c30

File tree

9 files changed

+36
-18
lines changed

9 files changed

+36
-18
lines changed

benchmarks/prefill_offline.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ def prefill_benchmark(tokens_list, engine, params, warmup):
8282
# pylint: disable-next=all
8383
warmup_text = "warmup" if warmup else "execute"
8484
it = time.time()
85-
prefill_result = engine.prefill(
85+
prefill_result, _ = engine.prefill(
8686
params=params,
8787
padded_tokens=prefill_tokens,
8888
true_length=len(prefill_tokens),

benchmarks/run_offline.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@ def run_prefill_time(engine, params, decode_state, seqlen):
4444
)
4545

4646
for _ in range(3):
47-
prefill_result = engine.prefill(
47+
prefill_result, _ = engine.prefill(
4848
params=params, padded_tokens=tokens, true_length=true_length
4949
)
5050
decode_state = engine.insert(
@@ -58,7 +58,7 @@ def run_prefill_time(engine, params, decode_state, seqlen):
5858
jax.profiler.start_trace(FLAGS.profiling_output)
5959
profiler_started = True
6060

61-
prefill_result = engine.prefill(
61+
prefill_result, _ = engine.prefill(
6262
params=params, padded_tokens=tokens, true_length=true_length
6363
)
6464
decode_state = engine.insert(

deps/JetStream

jetstream_pt/config.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,7 @@
3131
flags.DEFINE_string("size", "tiny", "size of model")
3232
flags.DEFINE_bool("quantize_kv_cache", False, "kv_cache_quantize")
3333
flags.DEFINE_integer("max_cache_length", 1024, "kv_cache_quantize")
34+
flags.DEFINE_integer("max_decode_length", 1024, "max length of generated text")
3435
flags.DEFINE_string("sharding_config", "", "config file for sharding")
3536
flags.DEFINE_bool(
3637
"shard_on_batch",
@@ -197,6 +198,7 @@ def create_engine_from_config_flags():
197198
batch_size=FLAGS.batch_size,
198199
quant_config=quant_config,
199200
max_cache_length=FLAGS.max_cache_length,
201+
max_decode_length=FLAGS.max_decode_length,
200202
sharding_config=sharding_file_name,
201203
shard_on_batch=FLAGS.shard_on_batch,
202204
ragged_mha=FLAGS.ragged_mha,

jetstream_pt/engine.py

Lines changed: 21 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,8 @@ def __init__(
9797
jax.config.update("jax_enable_x64", False)
9898

9999
self.prefill = jax.jit(
100-
self.prefill, out_shardings=self.get_prefix_destination_sharding()
100+
self.prefill,
101+
out_shardings=(self.get_prefix_destination_sharding(), None),
101102
)
102103
self.insert = jax.jit(
103104
self.insert,
@@ -247,7 +248,7 @@ def prefill(
247248
existing_prefix: Optional[Prefix] = None,
248249
padded_tokens: PrefillInputs, # PrefillInputs[jax.Array],
249250
true_length: int,
250-
) -> Prefix:
251+
) -> Tuple[Prefix, engine_api.ResultTokens]:
251252
if isinstance(padded_tokens, jax.Array):
252253
batched_token = padded_tokens.reshape(1, -1)
253254
else:
@@ -264,7 +265,6 @@ def prefill(
264265
)
265266
if len(logits.shape) == 3: # b, seqlen, num words
266267
logits = logits[0] # seqlen, num words
267-
268268
token = sampling_utils.sampling(
269269
logits[true_length - 1],
270270
self.rng,
@@ -273,7 +273,23 @@ def prefill(
273273
self.env.nucleus_topp,
274274
self.env.temperature,
275275
)
276-
276+
token_out = jnp.reshape(token, (1, 1))
277+
data = jnp.concatenate(
278+
[
279+
token_out, # First token
280+
jnp.ones_like(token_out), # validity of first token
281+
jnp.zeros((1, 1), dtype=jnp.int32), # length = 0
282+
],
283+
axis=-1,
284+
)
285+
length = token_out.shape[1]
286+
result = engine_api.ResultTokens(
287+
data=data,
288+
tokens_idx=(0, length),
289+
valid_idx=(length, 2 * length),
290+
length_idx=(2 * length, 2 * length + 1),
291+
samples_per_slot=1,
292+
)
277293
# truncate to true_length didnt work need to be out side of jit
278294
# caches = [
279295
# (jax.lax.dynamic_slice_in_dim(
@@ -282,7 +298,7 @@ def prefill(
282298
# v, seq_len - true_length, true_length, axis=2))
283299
# for k, v in updated_caches
284300
# ]
285-
return Prefix(token, updated_caches, true_length)
301+
return Prefix(token, updated_caches, true_length), result
286302

287303
def shrink_prefix(
288304
self,

run_interactive.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,7 @@ def main(argv):
6666
# pylint: disable-next=all
6767
if profiling_prefill:
6868
jax.profiler.start_trace(profiling_output)
69-
prefill_result = engine.prefill(
69+
prefill_result, _ = engine.prefill(
7070
params=params, padded_tokens=tokens, true_length=true_length
7171
)
7272
# pylint: disable-next=all

run_interactive_disaggregated.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -161,7 +161,7 @@ def main(argv):
161161
print(
162162
f"---- Do prefill in prefill engine pod_slice_name: {prefill_engine.pod_slice_name}"
163163
)
164-
prefill_result = prefill_engine.prefill(
164+
prefill_result, _ = prefill_engine.prefill(
165165
params=None, padded_tokens=tokens, true_length=true_length
166166
)
167167
print(

run_interactive_multiple_host.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -88,7 +88,7 @@ def main(argv):
8888
print(f"---- Encoded tokens are: {tokens}")
8989

9090
# pylint: disable-next=all
91-
prefill_result = engine.prefill(
91+
prefill_result, _ = engine.prefill(
9292
params=None, padded_tokens=tokens, true_length=true_length
9393
)
9494
# pylint: disable-next=all

tests/test_llama_e2e.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -128,7 +128,7 @@ def test_jetstream_llama2_seed(self):
128128
decode_state = engine.init_decode_state()
129129
slot = 0
130130
# pylint: disable-next=all
131-
prefill_result = engine.prefill(
131+
prefill_result, _ = engine.prefill(
132132
params=params, padded_tokens=padded_tokens, true_length=true_length
133133
)
134134

@@ -197,7 +197,7 @@ def _llama_e2e(self, env, model_arg):
197197
decode_state = engine.init_decode_state()
198198
slot = 0
199199
# pylint: disable-next=all
200-
prefill_result = engine.prefill(
200+
prefill_result, _ = engine.prefill(
201201
params=params, padded_tokens=padded_tokens, true_length=true_length
202202
)
203203

@@ -334,7 +334,7 @@ def test_llama_e2e_two_addtional_tokens(self):
334334
slot = 0
335335

336336
# pylint: disable-next=all
337-
prefill_result = engine.prefill(
337+
prefill_result, _ = engine.prefill(
338338
params=params, padded_tokens=padded_tokens, true_length=true_length
339339
)
340340

@@ -406,7 +406,7 @@ def test_llama_e2e_four_addtional_tokens(self):
406406
slot = 0
407407

408408
# pylint: disable-next=all
409-
prefill_result = engine.prefill(
409+
prefill_result, _ = engine.prefill(
410410
params=params, padded_tokens=padded_tokens, true_length=true_length
411411
)
412412

@@ -472,7 +472,7 @@ def test_llama_with_original_prefill_decode_32(self):
472472
# pylint: disable-next=all
473473
decode_state = engine.init_decode_state()
474474
# pylint: disable-next=all
475-
prefill_result = engine.prefill(
475+
prefill_result, _ = engine.prefill(
476476
params=params, padded_tokens=padded_tokens, true_length=true_length
477477
)
478478
out_tokens = prefill_result.token
@@ -547,7 +547,7 @@ def test_llama_with_original_prefill_decode(self):
547547
# pylint: disable-next=all
548548
decode_state = engine.init_decode_state()
549549
# pylint: disable-next=all
550-
prefill_result = engine.prefill(
550+
prefill_result, _ = engine.prefill(
551551
params=params, padded_tokens=padded_tokens, true_length=true_length
552552
)
553553
out_tokens = prefill_result.token

0 commit comments

Comments
 (0)