prefill_ray() now returns a [result, first_token] tuple, where first_token contains a JAX array. This causes a crash when the Ray results are fetched remotely:
job_id:06000000
:actor_name:ServeReplica:default:JetStreamDeployment
SIGTERM handler is not set because current thread is not the main thread.
Using address example-cluster-kuberay-head-svc.default.svc.cluster.local:6379 set in the environment variable RAY_ADDRESS
Connecting to existing Ray cluster at address: example-cluster-kuberay-head-svc.default.svc.cluster.local:6379...
Calling ray.init() again after it has already been called.
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
Traceback (most recent call last):
File "/home/ray/anaconda3/lib/python3.10/site-packages/jetstream/core/orchestrator.py", line 162, in run
super().run()
File "/home/ray/anaconda3/lib/python3.10/threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "/home/ray/anaconda3/lib/python3.10/site-packages/jetstream/core/orchestrator.py", line 507, in _prefill_thread
prefill_result, first_token = prefill_engine.prefill(
File "/tmp/ray/session_2024-07-12_17-07-57_303234_8/runtime_resources/working_dir_files/_ray_pkg_e66f370ed8382ac2/jetstream_pt/ray_engine.py", line 83, in prefill
return self.prefill_impl(
File "/tmp/ray/session_2024-07-12_17-07-57_303234_8/runtime_resources/working_dir_files/_ray_pkg_e66f370ed8382ac2/jetstream_pt/ray_engine.py", line 113, in prefill_impl
results = ray.get(all_outputs)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
return fn(*args, **kwargs)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
return func(*args, **kwargs)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 2623, in get
values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 861, in get_objects
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::PyTorchRayWorker.prefill_ray() (pid=14601, ip=10.104.7.5, actor_id=0721a490262f0d248878f59d06000000, repr=<jetstream_pt.ray_worker.PyTorchRayWorker object at 0x7974fc14e410>)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1479, in dumps
cp.dump(obj)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1245, in dump
return super().dump(obj)
File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 449, in __reduce__
fun, args, arr_state = self._value.__reduce__()
File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
return func(*args, **kwargs)
File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 602, in _value
raise RuntimeError(
RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.
@FanhaiLu1 Still seeing the same error after the latest fix. Can you take a look?
return func(*args, **kwargs)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 2656, in get
values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 871, in get_objects
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::PyTorchRayWorker.prefill_ray() (pid=904, ip=10.36.8.6, actor_id=7744b1669cd8fdc6809a72a502000000, repr=<jetstream_pt.ray_worker.PyTorchRayWorker object at 0x7f5e3259e380>)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1479, in dumps
cp.dump(obj)
File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1245, in dump
return super().dump(obj)
File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 449, in __reduce__
fun, args, arr_state = self._value.__reduce__()
File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
return func(*args, **kwargs)
File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 602, in _value
raise RuntimeError(
RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.
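The traceback shows the failure happening while cloudpickle serializes the return value: pickling a jax.Array reads its full value, which is not possible when the array spans devices owned by other processes. One possible workaround, sketched here under the assumption that first_token is replicated on every host (not confirmed in this thread), is to convert it to a host-side NumPy array inside the worker before returning it through Ray. The helper name `to_host_array` is hypothetical and not part of jetstream_pt; the attributes it reads (`is_fully_addressable`, `is_fully_replicated`, `addressable_shards`) are jax.Array properties.

```python
import numpy as np


def to_host_array(value):
    """Return a NumPy copy of `value` that can be pickled across processes.

    Hypothetical helper, not part of jetstream_pt. Assumes `value` is either
    array-like or a jax.Array exposing `is_fully_addressable`,
    `is_fully_replicated` and `addressable_shards`.
    """
    if getattr(value, "is_fully_addressable", True):
        # Every shard lives in this process, so a plain host copy works.
        return np.asarray(value)
    if getattr(value, "is_fully_replicated", False):
        # Each local shard holds the complete value; read one of them.
        return np.asarray(value.addressable_shards[0].data)
    # The array is split across hosts: no single process holds all of it.
    raise ValueError(
        "array is sharded across hosts; gather it first, for example with "
        "jax.experimental.multihost_utils.process_allgather"
    )
```

In the worker this would be applied to first_token just before the return, for example `return result, to_host_array(first_token)`. If first_token is actually sharded rather than replicated, the helper raises instead of silently returning a partial value.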