
Commit 52ec00f

Integrate disaggregated serving with JetStream (#117)

* add disaggregated server with ray support
* add run_server with ray
* format

1 parent 7f6e45f commit 52ec00f

File tree

3 files changed: +79 -24 lines


jetstream_pt/ray_engine.py

Lines changed: 5 additions & 3 deletions
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Any, Iterable, Optional, Union
+from typing import Any, Iterable, Optional, Union, Tuple, List

 import numpy as np
 import ray
@@ -180,7 +180,9 @@ def create_pytorch_ray_engine(
     decode_pod_slice_name: str = None,
     enable_jax_profiler: bool = False,
     jax_profiler_port: int = 9999,
-) -> Any:
+) -> Union[
+    PyTorchRayEngine, Tuple[List[PyTorchRayEngine], List[PyTorchRayEngine]]
+]:

   # Return tuple as reponse: issues/107
   supported_models = ["llama-2", "llama-3", "gemma"]
@@ -254,4 +256,4 @@ def create_pytorch_ray_engine(
       is_disaggregated=is_disaggregated,
       pod_slice_name=decode_pod_slice_name,
   )
-  return (prefill_engine, decode_engine)
+  return ([prefill_engine], [decode_engine])
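With this change, create_pytorch_ray_engine returns a pair of engine lists when is_disaggregated is set, and a single engine otherwise. Below is a minimal sketch of consuming the new disaggregated return value; the model, path, and slice values are placeholders rather than values from this commit, the remaining keyword arguments are assumed to keep their defaults, and a running Ray TPU cluster is assumed.

from jetstream_pt import ray_engine

# Disaggregated mode now yields ([prefill engines], [decode engines]);
# argument values below are illustrative placeholders only.
prefill_engines, decode_engines = ray_engine.create_pytorch_ray_engine(
    model_name="llama-2",
    tokenizer_path="/tmp/tokenizer.model",
    ckpt_path="/tmp/llama-2-ckpt",
    bf16_enable=True,
    param_size="7b",
    context_length=1024,
    batch_size=32,
    is_disaggregated=True,
    num_hosts=4,
    decode_pod_slice_name="decode-slice",
)

# Single-engine callers (see run_interactive_disaggregated.py below) take the
# first element of each list.
prefill_engine = prefill_engines[0]
decode_engine = decode_engines[0]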

run_interactive_disaggregated.py

Lines changed: 18 additions & 16 deletions
@@ -94,25 +94,27 @@ def create_disaggregated_engines():
   os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

   start = time.perf_counter()
-  prefill_engine, decode_engine = ray_engine.create_pytorch_ray_engine(
-      model_name=_MODEL_NAME.value,
-      tokenizer_path=_TOKENIZER_PATH.value,
-      ckpt_path=_CKPT_PATH.value,
-      bf16_enable=True,
-      param_size=_SIZE.value,
-      context_length=_CONTEXT_LENGTH.value,
-      batch_size=_BATCH_SIZE.value,
-      quantize_weights=_QUANTIZE_WEIGHTS.value,
-      quantize_kv=_QUANTIZE_KV_CACHE.value,
-      max_cache_length=_MAX_CACHE_LENGTH.value,
-      sharding_config=_SHARDING_CONFIG.value,
-      is_disaggregated=_IS_DISAGGREGATED.value,
-      num_hosts=_NUM_HOSTS.value,
-      decode_pod_slice_name=_DECODE_POD_SLICE_NAME.value,
+  prefill_engine_list, decode_engine_list = (
+      ray_engine.create_pytorch_ray_engine(
+          model_name=_MODEL_NAME.value,
+          tokenizer_path=_TOKENIZER_PATH.value,
+          ckpt_path=_CKPT_PATH.value,
+          bf16_enable=True,
+          param_size=_SIZE.value,
+          context_length=_CONTEXT_LENGTH.value,
+          batch_size=_BATCH_SIZE.value,
+          quantize_weights=_QUANTIZE_WEIGHTS.value,
+          quantize_kv=_QUANTIZE_KV_CACHE.value,
+          max_cache_length=_MAX_CACHE_LENGTH.value,
+          sharding_config=_SHARDING_CONFIG.value,
+          is_disaggregated=_IS_DISAGGREGATED.value,
+          num_hosts=_NUM_HOSTS.value,
+          decode_pod_slice_name=_DECODE_POD_SLICE_NAME.value,
+      )
   )

   print("Initialize engine", time.perf_counter() - start)
-  return (prefill_engine, decode_engine)
+  return (prefill_engine_list[0], decode_engine_list[0])


 # pylint: disable-next=all

run_server_with_ray.py

Lines changed: 56 additions & 5 deletions
@@ -37,6 +37,14 @@
 flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler")
 flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server")

+flags.DEFINE_bool(
+    "is_disaggregated", False, "Disaggregated serving if it's True"
+)
+
+flags.DEFINE_integer("num_hosts", 4, "Number of TPU host", required=False)
+
+flags.DEFINE_string("decode_pod_slice_name", "", "Decode pod slice name")
+

 def create_engine():
   """create a pytorch engine"""
@@ -64,6 +72,37 @@ def create_engine():
   return engine


+def create_disaggregated_engine():
+  """create a pytorch engine"""
+  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
+  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
+
+  start = time.perf_counter()
+  prefill_engine_list, decode_engine_list = (
+      ray_engine.create_pytorch_ray_engine(
+          model_name=FLAGS.model_name,
+          tokenizer_path=FLAGS.tokenizer_path,
+          ckpt_path=FLAGS.checkpoint_path,
+          bf16_enable=FLAGS.bf16_enable,
+          param_size=FLAGS.size,
+          context_length=FLAGS.context_length,
+          batch_size=FLAGS.batch_size,
+          quantize_weights=FLAGS.quantize_weights,
+          quantize_kv=FLAGS.quantize_kv_cache,
+          max_cache_length=FLAGS.max_cache_length,
+          sharding_config=FLAGS.sharding_config,
+          enable_jax_profiler=FLAGS.enable_jax_profiler,
+          jax_profiler_port=FLAGS.jax_profiler_port,
+          is_disaggregated=FLAGS.is_disaggregated,
+          num_hosts=FLAGS.num_hosts,
+          decode_pod_slice_name=FLAGS.decode_pod_slice_name,
+      )
+  )
+
+  print("Initialize engine", time.perf_counter() - start)
+  return (prefill_engine_list, decode_engine_list)
+
+
 # pylint: disable-next=all
 def main(argv: Sequence[str]):
   del argv
@@ -74,12 +113,24 @@ def main(argv: Sequence[str]):

   print(f"devices: {devices}")

-  engine = create_engine()
+  if FLAGS.is_disaggregated:
+    prefill_engine_list, decode_engine_list = create_disaggregated_engine()
+    chips = int(len(devices) / 2)
+    server_config = ServerConfig(
+        prefill_slices=(f"tpu={chips}",),
+        prefill_engine_create_fns=(lambda a: prefill_engine_list[0],),
+        generate_slices=(f"tpu={chips}",),
+        generate_engine_create_fns=(lambda a: decode_engine_list[0],),
+        is_ray_backend=True,
+    )
+
+  else:
+    engine = create_engine()
+    server_config = ServerConfig(
+        interleaved_slices=(f"tpu={len(devices)}",),
+        interleaved_engine_create_fns=(lambda a: engine,),
+    )

-  server_config = ServerConfig(
-      interleaved_slices=(f"tpu={len(devices)}",),
-      interleaved_engine_create_fns=(lambda a: engine,),
-  )
   print(f"server_config: {server_config}")

   jetstream_server = server_lib.run(
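The disaggregated branch sizes its slices by splitting the visible TPU chips in half: one half backs the prefill slice, the other the generate slice, and the server is told to use the Ray backend. A tiny worked example of that split follows; it is purely illustrative, using a stand-in device list in place of the one main() prints.

# Stand-in for the devices list used in main(); assume 8 visible TPU chips.
devices = list(range(8))
chips = int(len(devices) / 2)  # 4 chips per slice

prefill_slices = (f"tpu={chips}",)   # ("tpu=4",) for the prefill engine
generate_slices = (f"tpu={chips}",)  # ("tpu=4",) for the decode/generate engine
print(prefill_slices, generate_slices)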
