Ray engine crashes on multihost when fetching Jax.array from prefill_ray #150

Closed
richardsliu opened this issue Jul 16, 2024 · 1 comment · Fixed by #164 or #170

@richardsliu
Collaborator

prefill_ray() now returns a (result, first_token) tuple, where first_token contains a jax.Array. On multihost, fetching the Ray results with ray.get() then crashes, because the array spans devices on other hosts and cannot be serialized:

job_id:06000000
:actor_name:ServeReplica:default:JetStreamDeployment
SIGTERM handler is not set because current thread is not the main thread.
Using address example-cluster-kuberay-head-svc.default.svc.cluster.local:6379 set in the environment variable RAY_ADDRESS
Connecting to existing Ray cluster at address: example-cluster-kuberay-head-svc.default.svc.cluster.local:6379...
Calling ray.init() again after it has already been called.
normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.
Traceback (most recent call last):
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jetstream/core/orchestrator.py", line 162, in run
    super().run()
  File "/home/ray/anaconda3/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jetstream/core/orchestrator.py", line 507, in _prefill_thread
    prefill_result, first_token = prefill_engine.prefill(
  File "/tmp/ray/session_2024-07-12_17-07-57_303234_8/runtime_resources/working_dir_files/_ray_pkg_e66f370ed8382ac2/jetstream_pt/ray_engine.py", line 83, in prefill
    return self.prefill_impl(
  File "/tmp/ray/session_2024-07-12_17-07-57_303234_8/runtime_resources/working_dir_files/_ray_pkg_e66f370ed8382ac2/jetstream_pt/ray_engine.py", line 113, in prefill_impl
    results = ray.get(all_outputs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 2623, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 861, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::PyTorchRayWorker.prefill_ray() (pid=14601, ip=10.104.7.5, actor_id=0721a490262f0d248878f59d06000000, repr=<jetstream_pt.ray_worker.PyTorchRayWorker object at 0x7974fc14e410>)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1479, in dumps
    cp.dump(obj)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1245, in dump
    return super().dump(obj)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 449, in __reduce__
    fun, args, arr_state = self._value.__reduce__()
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 602, in _value
    raise RuntimeError(
RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.
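
For readers unfamiliar with the failure mode, here is a minimal sketch of what the traceback shows, using only the standard library. `GlobalArrayStandIn`, `to_host_local`, and `try_pickle` are hypothetical names invented for illustration; they are not part of jetstream_pt, Ray, or JAX. The stand-in mimics a jax.Array whose `__reduce__` needs the full value, which is unavailable when some shards live on other hosts. The sketch assumes the worker can return host-local data instead; in real code the conversion would more likely use one of the options named in the error message (`jax.experimental.multihost_utils.process_allgather` or `.addressable_shards`), and whether the local shard alone is sufficient depends on how first_token is sharded.

```python
import pickle


class GlobalArrayStandIn:
    """Hypothetical stand-in for a jax.Array sharded across several hosts."""

    def __init__(self, local_shard, fully_addressable):
        self.local_shard = list(local_shard)
        self.fully_addressable = fully_addressable

    @property
    def _value(self):
        # Mirrors the check in the traceback: the full value is only
        # available when every shard lives on this process.
        if not self.fully_addressable:
            raise RuntimeError("array spans non-addressable devices")
        return self.local_shard

    def __reduce__(self):
        # Pickling asks for the full value, so it fails on multihost.
        return (GlobalArrayStandIn, (self._value, True))


def to_host_local(result):
    """Replace each stand-in array in a result tuple with plain Python data."""
    return tuple(
        item.local_shard if isinstance(item, GlobalArrayStandIn) else item
        for item in result
    )


def try_pickle(obj):
    """Return (True, payload) on success or (False, error message) on failure."""
    try:
        return True, pickle.dumps(obj)
    except RuntimeError as err:
        return False, str(err)


# Shape of the prefill output described above: (result, first_token).
prefill_output = ("prefix-state", GlobalArrayStandIn([7], fully_addressable=False))

ok_raw, detail = try_pickle(prefill_output)                  # fails in __reduce__
ok_local, _ = try_pickle(to_host_local(prefill_output))      # plain data pickles
```

The point of the sketch is that the exception is raised on the worker while it serializes its return value, so the conversion to host-local data has to happen inside the remote method, before it returns, and not on the caller's side after ray.get().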
FanhaiLu1 self-assigned this Jul 16, 2024
@richardsliu
Collaborator Author

@FanhaiLu1 I'm still seeing the same error after the latest fix. Can you take a look?

    return func(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 2656, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/worker.py", line 871, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::PyTorchRayWorker.prefill_ray() (pid=904, ip=10.36.8.6, actor_id=7744b1669cd8fdc6809a72a502000000, repr=<jetstream_pt.ray_worker.PyTorchRayWorker object at 0x7f5e3259e380>)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1479, in dumps
    cp.dump(obj)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/cloudpickle/cloudpickle.py", line 1245, in dump
    return super().dump(obj)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 449, in __reduce__
    fun, args, arr_state = self._value.__reduce__()
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 602, in _value
    raise RuntimeError(
RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.
