File tree Expand file tree Collapse file tree 2 files changed +3
-6
lines changed Expand file tree Collapse file tree 2 files changed +3
-6
lines changed Original file line number Diff line number Diff line change @@ -113,7 +113,7 @@ def prefill_impl(
113
113
results = ray .get (all_outputs )
114
114
# The prefill function does not return any values;
115
115
# the worker itself manages and maintains the prefill states.
116
- return results [0 ]
116
+ return None , results [0 ]
117
117
118
118
def transfer (self , np_prefix : NpPrefix ) -> Any :
119
119
"""Store prefill result into object store, then transfer to decode engine workers."""
Original file line number Diff line number Diff line change @@ -454,7 +454,7 @@ def prefill_ray(
454
454
existing_prefix : Optional [Prefix ] = None ,
455
455
padded_tokens : PrefillInputs , # PrefillInputs[np.ndarray],
456
456
true_length : int ,
457
- ) -> tuple [ Prefix , engine_api .ResultTokens ] :
457
+ ) -> engine_api .ResultTokens :
458
458
"""Do prefill in ray worker"""
459
459
logits , updated_caches = self .prefill (
460
460
params = params ,
@@ -466,9 +466,6 @@ def prefill_ray(
466
466
logits = logits [0 ]
467
467
468
468
token = np .argmax (logits [true_length - 1 ])
469
- updated_caches = multihost_utils .process_allgather (
470
- updated_caches , tiled = True
471
- )
472
469
prefix = Prefix (token , updated_caches , true_length )
473
470
self .prefix_queue .put (prefix , block = False )
474
471
@@ -490,7 +487,7 @@ def prefill_ray(
490
487
samples_per_slot = 1 ,
491
488
)
492
489
493
- return prefix , result
490
+ return result
494
491
495
492
def _convert_to_np_caches (
496
493
self , caches : List [Tuple [jax .Array , jax .Array ]]
You can’t perform that action at this time.
0 commit comments