Skip to content

Commit 94449c3

Browse files
authored
Make Ray engine and worker process prefill returning first token (#147)
* result_token fix * fix return * Format
1 parent 663c102 commit 94449c3

File tree

2 files changed

+42
-6
lines changed

2 files changed

+42
-6
lines changed

jetstream_pt/ray_engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def prefill(
7070
existing_prefix: Optional[Prefix] = None,
7171
padded_tokens: np.ndarray, # PrefillInputs[np.ndarray],
7272
true_length: int,
73-
) -> Prefix:
73+
) -> Tuple[Prefix, engine_api.ResultTokens]:
7474
if self.is_disaggregated:
7575
return self.prefill_impl(
7676
params=params,
@@ -95,7 +95,7 @@ def prefill_impl(
9595
existing_prefix: Optional[Prefix] = None,
9696
padded_tokens: np.ndarray, # PrefillInputs[np.ndarray],
9797
true_length: int,
98-
) -> Prefix:
98+
) -> Tuple[Prefix, engine_api.ResultTokens]:
9999
all_outputs = []
100100
for worker in self.engine_workers:
101101
prefill_func = (

jetstream_pt/ray_worker.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ def prefill_ray(
461461
existing_prefix: Optional[Prefix] = None,
462462
padded_tokens: PrefillInputs, # PrefillInputs[np.ndarray],
463463
true_length: int,
464-
) -> None:
464+
) -> tuple[Prefix, engine_api.ResultTokens]:
465465
"""Do prefill in ray worker"""
466466
logits, updated_caches = self.prefill(
467467
params=params,
@@ -476,7 +476,25 @@ def prefill_ray(
476476
prefix = Prefix(token, updated_caches, true_length)
477477
self.prefix_queue.put(prefix, block=False)
478478

479-
return token
479+
token_out = jnp.reshape(token, (1, 1))
480+
data = jnp.concatenate(
481+
[
482+
token_out, # First token
483+
jnp.ones_like(token_out), # validity of first token
484+
jnp.zeros((1, 1), dtype=jnp.int32), # length = 0
485+
],
486+
axis=-1,
487+
)
488+
length = token_out.shape[1]
489+
result = engine_api.ResultTokens(
490+
data=data,
491+
tokens_idx=(0, length),
492+
valid_idx=(length, 2 * length),
493+
length_idx=(2 * length, 2 * length + 1),
494+
samples_per_slot=1,
495+
)
496+
497+
return prefix, result
480498

481499
def _convert_to_np_caches(
482500
self, caches: List[Tuple[jax.Array, jax.Array]]
@@ -495,7 +513,7 @@ def prefill_ray_disaggregation(
495513
existing_prefix: Optional[Prefix] = None,
496514
padded_tokens: PrefillInputs, # PrefillInputs[np.ndarray],
497515
true_length: int,
498-
) -> Any:
516+
) -> tuple[NpPrefix, engine_api.ResultTokens]:
499517
"""Do prefill in ray worker"""
500518
logits, updated_caches = self.prefill(
501519
params=params,
@@ -513,7 +531,25 @@ def prefill_ray_disaggregation(
513531
np_update_caches = self._convert_to_np_caches(updated_caches)
514532
np_prefix = NpPrefix(token, np_update_caches, true_length)
515533

516-
return np_prefix
534+
token_out = jnp.reshape(token, (1, 1))
535+
data = jnp.concatenate(
536+
[
537+
token_out, # First token
538+
jnp.ones_like(token_out), # validity of first token
539+
jnp.zeros((1, 1), dtype=jnp.int32), # length = 0
540+
],
541+
axis=-1,
542+
)
543+
length = token_out.shape[1]
544+
result = engine_api.ResultTokens(
545+
data=data,
546+
tokens_idx=(0, length),
547+
valid_idx=(length, 2 * length),
548+
length_idx=(2 * length, 2 * length + 1),
549+
samples_per_slot=1,
550+
)
551+
552+
return np_prefix, result
517553

518554
def transfer(self, np_prefix: NpPrefix) -> Any:
519555
"""Transfer prefill result from object store to HBM"""

0 commit comments

Comments
 (0)