Skip to content

Commit 50a6d10

Browse files
Make prefilling return first token for loadgen integration (#143)
* Make prefilling return the first token for loadgen integration
* Minor fix and lint
* Enable passing of max_decode_length as a flag
1 parent 9717eb9 commit 50a6d10

File tree

9 files changed

+36
-18
lines changed

9 files changed

+36
-18
lines changed

benchmarks/prefill_offline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def prefill_benchmark(tokens_list, engine, params, warmup):
8282
# pylint: disable-next=all
8383
warmup_text = "warmup" if warmup else "execute"
8484
it = time.time()
85-
prefill_result = engine.prefill(
85+
prefill_result, _ = engine.prefill(
8686
params=params,
8787
padded_tokens=prefill_tokens,
8888
true_length=len(prefill_tokens),

benchmarks/run_offline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def run_prefill_time(engine, params, decode_state, seqlen):
4343
)
4444

4545
for _ in range(3):
46-
prefill_result = engine.prefill(
46+
prefill_result, _ = engine.prefill(
4747
params=params, padded_tokens=tokens, true_length=true_length
4848
)
4949
decode_state = engine.insert(
@@ -53,7 +53,7 @@ def run_prefill_time(engine, params, decode_state, seqlen):
5353
nums = 5
5454
start = time.perf_counter()
5555
for i in range(nums):
56-
prefill_result = engine.prefill(
56+
prefill_result, _ = engine.prefill(
5757
params=params, padded_tokens=tokens, true_length=true_length
5858
)
5959
decode_state = engine.insert(

deps/JetStream

jetstream_pt/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
flags.DEFINE_string("size", "tiny", "size of model")
3232
flags.DEFINE_bool("quantize_kv_cache", False, "kv_cache_quantize")
3333
flags.DEFINE_integer("max_cache_length", 1024, "kv_cache_quantize")
34+
flags.DEFINE_integer("max_decode_length", 1024, "max length of generated text")
3435
flags.DEFINE_string("sharding_config", "", "config file for sharding")
3536
flags.DEFINE_bool(
3637
"shard_on_batch",
@@ -173,6 +174,7 @@ def create_engine_from_config_flags():
173174
batch_size=FLAGS.batch_size,
174175
quant_config=quant_config,
175176
max_cache_length=FLAGS.max_cache_length,
177+
max_decode_length=FLAGS.max_decode_length,
176178
sharding_config=sharding_file_name,
177179
shard_on_batch=FLAGS.shard_on_batch,
178180
ragged_mha=FLAGS.ragged_mha,

jetstream_pt/engine.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ def __init__(
9797
jax.config.update("jax_enable_x64", False)
9898

9999
self.prefill = jax.jit(
100-
self.prefill, out_shardings=self.get_prefix_destination_sharding()
100+
self.prefill,
101+
out_shardings=(self.get_prefix_destination_sharding(), None),
101102
)
102103
self.insert = jax.jit(
103104
self.insert,
@@ -243,7 +244,7 @@ def prefill(
243244
existing_prefix: Optional[Prefix] = None,
244245
padded_tokens: PrefillInputs, # PrefillInputs[jax.Array],
245246
true_length: int,
246-
) -> Prefix:
247+
) -> Tuple[Prefix, engine_api.ResultTokens]:
247248
if isinstance(padded_tokens, jax.Array):
248249
batched_token = padded_tokens.reshape(1, -1)
249250
else:
@@ -260,7 +261,6 @@ def prefill(
260261
)
261262
if len(logits.shape) == 3: # b, seqlen, num words
262263
logits = logits[0] # seqlen, num words
263-
264264
token = sampling_utils.sampling(
265265
logits[true_length - 1],
266266
self.rng,
@@ -269,7 +269,23 @@ def prefill(
269269
self.env.nucleus_topp,
270270
self.env.temperature,
271271
)
272-
272+
token_out = jnp.reshape(token, (1, 1))
273+
data = jnp.concatenate(
274+
[
275+
token_out, # First token
276+
jnp.ones_like(token_out), # validity of first token
277+
jnp.zeros((1, 1), dtype=jnp.int32), # length = 0
278+
],
279+
axis=-1,
280+
)
281+
length = token_out.shape[1]
282+
result = engine_api.ResultTokens(
283+
data=data,
284+
tokens_idx=(0, length),
285+
valid_idx=(length, 2 * length),
286+
length_idx=(2 * length, 2 * length + 1),
287+
samples_per_slot=1,
288+
)
273289
# truncate to true_length didnt work need to be out side of jit
274290
# caches = [
275291
# (jax.lax.dynamic_slice_in_dim(
@@ -278,7 +294,7 @@ def prefill(
278294
# v, seq_len - true_length, true_length, axis=2))
279295
# for k, v in updated_caches
280296
# ]
281-
return Prefix(token, updated_caches, true_length)
297+
return Prefix(token, updated_caches, true_length), result
282298

283299
def shrink_prefix(
284300
self,

run_interactive.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def main(argv):
6262
print(f"---- Encoded tokens are: {tokens}")
6363

6464
# pylint: disable-next=all
65-
prefill_result = engine.prefill(
65+
prefill_result, _ = engine.prefill(
6666
params=params, padded_tokens=tokens, true_length=true_length
6767
)
6868
# pylint: disable-next=all

run_interactive_disaggregated.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def main(argv):
161161
print(
162162
f"---- Do prefill in prefill engine pod_slice_name: {prefill_engine.pod_slice_name}"
163163
)
164-
prefill_result = prefill_engine.prefill(
164+
prefill_result, _ = prefill_engine.prefill(
165165
params=None, padded_tokens=tokens, true_length=true_length
166166
)
167167
print(

run_interactive_multiple_host.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def main(argv):
8888
print(f"---- Encoded tokens are: {tokens}")
8989

9090
# pylint: disable-next=all
91-
prefill_result = engine.prefill(
91+
prefill_result, _ = engine.prefill(
9292
params=None, padded_tokens=tokens, true_length=true_length
9393
)
9494
# pylint: disable-next=all

tests/test_llama_e2e.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def test_jetstream_llama2_seed(self):
127127
decode_state = engine.init_decode_state()
128128
slot = 0
129129
# pylint: disable-next=all
130-
prefill_result = engine.prefill(
130+
prefill_result, _ = engine.prefill(
131131
params=params, padded_tokens=padded_tokens, true_length=true_length
132132
)
133133

@@ -193,7 +193,7 @@ def _llama_e2e(self, env, model_arg):
193193
decode_state = engine.init_decode_state()
194194
slot = 0
195195
# pylint: disable-next=all
196-
prefill_result = engine.prefill(
196+
prefill_result, _ = engine.prefill(
197197
params=params, padded_tokens=padded_tokens, true_length=true_length
198198
)
199199

@@ -278,7 +278,7 @@ def test_llama_e2e_two_addtional_tokens(self):
278278
slot = 0
279279

280280
# pylint: disable-next=all
281-
prefill_result = engine.prefill(
281+
prefill_result, _ = engine.prefill(
282282
params=params, padded_tokens=padded_tokens, true_length=true_length
283283
)
284284

@@ -350,7 +350,7 @@ def test_llama_e2e_four_addtional_tokens(self):
350350
slot = 0
351351

352352
# pylint: disable-next=all
353-
prefill_result = engine.prefill(
353+
prefill_result, _ = engine.prefill(
354354
params=params, padded_tokens=padded_tokens, true_length=true_length
355355
)
356356

@@ -416,7 +416,7 @@ def test_llama_with_original_prefill_decode_32(self):
416416
# pylint: disable-next=all
417417
decode_state = engine.init_decode_state()
418418
# pylint: disable-next=all
419-
prefill_result = engine.prefill(
419+
prefill_result, _ = engine.prefill(
420420
params=params, padded_tokens=padded_tokens, true_length=true_length
421421
)
422422
out_tokens = prefill_result.token
@@ -491,7 +491,7 @@ def test_llama_with_original_prefill_decode(self):
491491
# pylint: disable-next=all
492492
decode_state = engine.init_decode_state()
493493
# pylint: disable-next=all
494-
prefill_result = engine.prefill(
494+
prefill_result, _ = engine.prefill(
495495
params=params, padded_tokens=padded_tokens, true_length=true_length
496496
)
497497
out_tokens = prefill_result.token

0 commit comments

Comments (0)