We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 98a8e28 commit 508d71cCopy full SHA for 508d71c
jetstream_pt/ray_engine.py
@@ -4,6 +4,7 @@
4
5
import numpy as np
6
import ray
7
+from ray.runtime_env import RuntimeEnv
8
from ray.util.accelerators import tpu
9
10
from jetstream.engine import engine_api, tokenizer_pb2
@@ -241,7 +242,8 @@ def create_pytorch_ray_engine(
241
242
), f"num_hosts (current value {num_hosts}) should be a positive number"
243
# pylint: disable-next=all
244
engine_worker_with_tpu_resource = PyTorchRayWorker.options(
- resources={"TPU": 4}
245
+ resources={"TPU": 4},
246
+ runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "tpu,cpu"}),
247
)
248
engine_workers = []
249
for _ in range(num_hosts):
0 commit comments