@@ -461,7 +461,7 @@ def prefill_ray(
461
461
existing_prefix : Optional [Prefix ] = None ,
462
462
padded_tokens : PrefillInputs , # PrefillInputs[np.ndarray],
463
463
true_length : int ,
464
- ) -> None :
464
+ ) -> tuple [ Prefix , engine_api . ResultTokens ] :
465
465
"""Do prefill in ray worker"""
466
466
logits , updated_caches = self .prefill (
467
467
params = params ,
@@ -476,7 +476,25 @@ def prefill_ray(
476
476
prefix = Prefix (token , updated_caches , true_length )
477
477
self .prefix_queue .put (prefix , block = False )
478
478
479
- return token
479
+ token_out = jnp .reshape (token , (1 , 1 ))
480
+ data = jnp .concatenate (
481
+ [
482
+ token_out , # First token
483
+ jnp .ones_like (token_out ), # validity of first token
484
+ jnp .zeros ((1 , 1 ), dtype = jnp .int32 ), # length = 0
485
+ ],
486
+ axis = - 1 ,
487
+ )
488
+ length = token_out .shape [1 ]
489
+ result = engine_api .ResultTokens (
490
+ data = data ,
491
+ tokens_idx = (0 , length ),
492
+ valid_idx = (length , 2 * length ),
493
+ length_idx = (2 * length , 2 * length + 1 ),
494
+ samples_per_slot = 1 ,
495
+ )
496
+
497
+ return prefix , result
480
498
481
499
def _convert_to_np_caches (
482
500
self , caches : List [Tuple [jax .Array , jax .Array ]]
@@ -495,7 +513,7 @@ def prefill_ray_disaggregation(
495
513
existing_prefix : Optional [Prefix ] = None ,
496
514
padded_tokens : PrefillInputs , # PrefillInputs[np.ndarray],
497
515
true_length : int ,
498
- ) -> Any :
516
+ ) -> tuple [ NpPrefix , engine_api . ResultTokens ] :
499
517
"""Do prefill in ray worker"""
500
518
logits , updated_caches = self .prefill (
501
519
params = params ,
@@ -513,7 +531,25 @@ def prefill_ray_disaggregation(
513
531
np_update_caches = self ._convert_to_np_caches (updated_caches )
514
532
np_prefix = NpPrefix (token , np_update_caches , true_length )
515
533
516
- return np_prefix
534
+ token_out = jnp .reshape (token , (1 , 1 ))
535
+ data = jnp .concatenate (
536
+ [
537
+ token_out , # First token
538
+ jnp .ones_like (token_out ), # validity of first token
539
+ jnp .zeros ((1 , 1 ), dtype = jnp .int32 ), # length = 0
540
+ ],
541
+ axis = - 1 ,
542
+ )
543
+ length = token_out .shape [1 ]
544
+ result = engine_api .ResultTokens (
545
+ data = data ,
546
+ tokens_idx = (0 , length ),
547
+ valid_idx = (length , 2 * length ),
548
+ length_idx = (2 * length , 2 * length + 1 ),
549
+ samples_per_slot = 1 ,
550
+ )
551
+
552
+ return np_prefix , result
517
553
518
554
def transfer (self , np_prefix : NpPrefix ) -> Any :
519
555
"""Transfer prefill result from object store to HBM"""
0 commit comments