From 5490ec47c1c426468c5611aee4f933fe0eac94f2 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 5 Jun 2024 22:33:24 +0000 Subject: [PATCH 1/3] add diaggregated server with ray support --- jetstream_pt/ray_engine.py | 6 ++-- run_interactive_disaggregated.py | 4 +-- run_server_with_ray.py | 57 ++++++++++++++++++++++++++++++-- 3 files changed, 59 insertions(+), 8 deletions(-) diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index 13d11edc..f4f05329 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Any, Iterable, Optional, Union +from typing import Any, Iterable, Optional, Union, Tuple, List import numpy as np import ray @@ -180,7 +180,7 @@ def create_pytorch_ray_engine( decode_pod_slice_name: str = None, enable_jax_profiler: bool = False, jax_profiler_port: int = 9999, -) -> Any: +) -> Union[PyTorchRayEngine, Tuple[List[PyTorchRayEngine], List[PyTorchRayEngine]]]: # Return tuple as reponse: issues/107 supported_models = ["llama-2", "llama-3", "gemma"] @@ -254,4 +254,4 @@ def create_pytorch_ray_engine( is_disaggregated=is_disaggregated, pod_slice_name=decode_pod_slice_name, ) - return (prefill_engine, decode_engine) + return ([prefill_engine], [decode_engine]) diff --git a/run_interactive_disaggregated.py b/run_interactive_disaggregated.py index b086d365..12749fbb 100644 --- a/run_interactive_disaggregated.py +++ b/run_interactive_disaggregated.py @@ -94,7 +94,7 @@ def create_disaggregated_engines(): os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" start = time.perf_counter() - prefill_engine, decode_engine = ray_engine.create_pytorch_ray_engine( + prefill_engine_list, decode_engine_list = ray_engine.create_pytorch_ray_engine( model_name=_MODEL_NAME.value, tokenizer_path=_TOKENIZER_PATH.value, ckpt_path=_CKPT_PATH.value, @@ -112,7 +112,7 @@ def create_disaggregated_engines(): ) print("Initialize engine", time.perf_counter() - start) - return (prefill_engine, decode_engine) + return (prefill_engine_list[0], decode_engine_list[0]) # pylint: disable-next=all diff --git a/run_server_with_ray.py b/run_server_with_ray.py index 325bc108..2059ce91 100644 --- a/run_server_with_ray.py +++ b/run_server_with_ray.py @@ -37,6 +37,18 @@ flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler") flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server") +flags.DEFINE_bool( + "is_disaggregated", False, "Disaggregated serving if it's True" +) + +flags.DEFINE_integer( + "num_hosts", 4, "Number of TPU host", required=False +) + +flags.DEFINE_string( + "decode_pod_slice_name", "", "Decode pod slice name" +) + def create_engine(): """create a pytorch engine""" @@ -63,6 +75,34 @@ def create_engine(): print("Initialize engine", time.perf_counter() - start) return engine +def create_disaggregated_engine(): + """create a pytorch engine""" + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + + start = time.perf_counter() + prefill_engine_list, decode_engine_list = ray_engine.create_pytorch_ray_engine( + model_name=FLAGS.model_name, + tokenizer_path=FLAGS.tokenizer_path, + ckpt_path=FLAGS.checkpoint_path, + bf16_enable=FLAGS.bf16_enable, + param_size=FLAGS.size, + context_length=FLAGS.context_length, + batch_size=FLAGS.batch_size, + quantize_weights=FLAGS.quantize_weights, + quantize_kv=FLAGS.quantize_kv_cache, + max_cache_length=FLAGS.max_cache_length, + sharding_config=FLAGS.sharding_config, + enable_jax_profiler=FLAGS.enable_jax_profiler, + jax_profiler_port=FLAGS.jax_profiler_port, + is_disaggregated=FLAGS.is_disaggregated, + num_hosts=FLAGS.num_hosts, + decode_pod_slice_name=FLAGS.decode_pod_slice_name, + ) + + print("Initialize engine", time.perf_counter() - start) + return (prefill_engine_list, decode_engine_list) + # pylint: disable-next=all def main(argv: Sequence[str]): @@ -74,12 +114,23 @@ def main(argv: Sequence[str]): print(f"devices: {devices}") - engine = create_engine() - server_config = ServerConfig( + if FLAGS.is_disaggregated: + prefill_engine_list, decode_engine_list = create_disaggregated_engine() + server_config = ServerConfig( + prefill_slices=(f"tpu={len(devices)}",), + prefill_engine_create_fns=(lambda a: prefill_engine_list[0],), + generate_slices=(f"tpu={len(devices)}",), + generate_engine_create_fns=(lambda a: decode_engine_list[0],), + ) + + else: + engine = create_engine() + server_config = ServerConfig( interleaved_slices=(f"tpu={len(devices)}",), interleaved_engine_create_fns=(lambda a: engine,), - ) + ) + print(f"server_config: {server_config}") jetstream_server = server_lib.run( From 81abc4f12aec4b747721f57187deec0f19b67189 Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 5 Jun 2024 23:29:00 +0000 Subject: [PATCH 2/3] add run_server wity ray --- run_server_with_ray.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/run_server_with_ray.py b/run_server_with_ray.py index 2059ce91..75644c50 100644 --- a/run_server_with_ray.py +++ b/run_server_with_ray.py @@ -117,11 +117,13 @@ def main(argv: Sequence[str]): if FLAGS.is_disaggregated: prefill_engine_list, decode_engine_list = create_disaggregated_engine() + chips = int(len(devices)/2) server_config = ServerConfig( - prefill_slices=(f"tpu={len(devices)}",), + prefill_slices=(f"tpu={chips}",), prefill_engine_create_fns=(lambda a: prefill_engine_list[0],), - generate_slices=(f"tpu={len(devices)}",), + generate_slices=(f"tpu={chips}",), generate_engine_create_fns=(lambda a: decode_engine_list[0],), + is_ray_backend=True ) else: From bfc096f2002eadb1fcf3c3c8c1a851349f1bbf7f Mon Sep 17 00:00:00 2001 From: FanhaiLu1 Date: Wed, 5 Jun 2024 23:40:16 +0000 Subject: [PATCH 3/3] format --- jetstream_pt/ray_engine.py | 4 +- run_interactive_disaggregated.py | 32 ++++++++-------- run_server_with_ray.py | 64 ++++++++++++++++---------------- 3 files changed, 51 insertions(+), 49 deletions(-) diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index f4f05329..d56f4ead 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -180,7 +180,9 @@ def create_pytorch_ray_engine( decode_pod_slice_name: str = None, enable_jax_profiler: bool = False, jax_profiler_port: int = 9999, -) -> Union[PyTorchRayEngine, Tuple[List[PyTorchRayEngine], List[PyTorchRayEngine]]]: +) -> Union[ + PyTorchRayEngine, Tuple[List[PyTorchRayEngine], List[PyTorchRayEngine]] +]: # Return tuple as reponse: issues/107 supported_models = ["llama-2", "llama-3", "gemma"] diff --git a/run_interactive_disaggregated.py b/run_interactive_disaggregated.py index 12749fbb..6f908266 100644 --- a/run_interactive_disaggregated.py +++ b/run_interactive_disaggregated.py @@ -94,21 +94,23 @@ def create_disaggregated_engines(): os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" start = time.perf_counter() - prefill_engine_list, decode_engine_list = ray_engine.create_pytorch_ray_engine( - model_name=_MODEL_NAME.value, - tokenizer_path=_TOKENIZER_PATH.value, - ckpt_path=_CKPT_PATH.value, - bf16_enable=True, - param_size=_SIZE.value, - context_length=_CONTEXT_LENGTH.value, - batch_size=_BATCH_SIZE.value, - quantize_weights=_QUANTIZE_WEIGHTS.value, - quantize_kv=_QUANTIZE_KV_CACHE.value, - max_cache_length=_MAX_CACHE_LENGTH.value, - sharding_config=_SHARDING_CONFIG.value, - is_disaggregated=_IS_DISAGGREGATED.value, - num_hosts=_NUM_HOSTS.value, - decode_pod_slice_name=_DECODE_POD_SLICE_NAME.value, + prefill_engine_list, decode_engine_list = ( + ray_engine.create_pytorch_ray_engine( + model_name=_MODEL_NAME.value, + tokenizer_path=_TOKENIZER_PATH.value, + ckpt_path=_CKPT_PATH.value, + bf16_enable=True, + param_size=_SIZE.value, + context_length=_CONTEXT_LENGTH.value, + batch_size=_BATCH_SIZE.value, + quantize_weights=_QUANTIZE_WEIGHTS.value, + quantize_kv=_QUANTIZE_KV_CACHE.value, + max_cache_length=_MAX_CACHE_LENGTH.value, + sharding_config=_SHARDING_CONFIG.value, + is_disaggregated=_IS_DISAGGREGATED.value, + num_hosts=_NUM_HOSTS.value, + decode_pod_slice_name=_DECODE_POD_SLICE_NAME.value, + ) ) print("Initialize engine", time.perf_counter() - start) diff --git a/run_server_with_ray.py b/run_server_with_ray.py index 75644c50..75c41164 100644 --- a/run_server_with_ray.py +++ b/run_server_with_ray.py @@ -41,13 +41,9 @@ "is_disaggregated", False, "Disaggregated serving if it's True" ) -flags.DEFINE_integer( - "num_hosts", 4, "Number of TPU host", required=False -) +flags.DEFINE_integer("num_hosts", 4, "Number of TPU host", required=False) -flags.DEFINE_string( - "decode_pod_slice_name", "", "Decode pod slice name" -) +flags.DEFINE_string("decode_pod_slice_name", "", "Decode pod slice name") def create_engine(): @@ -75,29 +71,32 @@ def create_engine(): print("Initialize engine", time.perf_counter() - start) return engine + def create_disaggregated_engine(): """create a pytorch engine""" jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" start = time.perf_counter() - prefill_engine_list, decode_engine_list = ray_engine.create_pytorch_ray_engine( - model_name=FLAGS.model_name, - tokenizer_path=FLAGS.tokenizer_path, - ckpt_path=FLAGS.checkpoint_path, - bf16_enable=FLAGS.bf16_enable, - param_size=FLAGS.size, - context_length=FLAGS.context_length, - batch_size=FLAGS.batch_size, - quantize_weights=FLAGS.quantize_weights, - quantize_kv=FLAGS.quantize_kv_cache, - max_cache_length=FLAGS.max_cache_length, - sharding_config=FLAGS.sharding_config, - enable_jax_profiler=FLAGS.enable_jax_profiler, - jax_profiler_port=FLAGS.jax_profiler_port, - is_disaggregated=FLAGS.is_disaggregated, - num_hosts=FLAGS.num_hosts, - decode_pod_slice_name=FLAGS.decode_pod_slice_name, + prefill_engine_list, decode_engine_list = ( + ray_engine.create_pytorch_ray_engine( + model_name=FLAGS.model_name, + tokenizer_path=FLAGS.tokenizer_path, + ckpt_path=FLAGS.checkpoint_path, + bf16_enable=FLAGS.bf16_enable, + param_size=FLAGS.size, + context_length=FLAGS.context_length, + batch_size=FLAGS.batch_size, + quantize_weights=FLAGS.quantize_weights, + quantize_kv=FLAGS.quantize_kv_cache, + max_cache_length=FLAGS.max_cache_length, + sharding_config=FLAGS.sharding_config, + enable_jax_profiler=FLAGS.enable_jax_profiler, + jax_profiler_port=FLAGS.jax_profiler_port, + is_disaggregated=FLAGS.is_disaggregated, + num_hosts=FLAGS.num_hosts, + decode_pod_slice_name=FLAGS.decode_pod_slice_name, + ) ) print("Initialize engine", time.perf_counter() - start) @@ -114,23 +113,22 @@ def main(argv: Sequence[str]): print(f"devices: {devices}") - if FLAGS.is_disaggregated: prefill_engine_list, decode_engine_list = create_disaggregated_engine() - chips = int(len(devices)/2) + chips = int(len(devices) / 2) server_config = ServerConfig( - prefill_slices=(f"tpu={chips}",), - prefill_engine_create_fns=(lambda a: prefill_engine_list[0],), - generate_slices=(f"tpu={chips}",), - generate_engine_create_fns=(lambda a: decode_engine_list[0],), - is_ray_backend=True + prefill_slices=(f"tpu={chips}",), + prefill_engine_create_fns=(lambda a: prefill_engine_list[0],), + generate_slices=(f"tpu={chips}",), + generate_engine_create_fns=(lambda a: decode_engine_list[0],), + is_ray_backend=True, ) - + else: engine = create_engine() server_config = ServerConfig( - interleaved_slices=(f"tpu={len(devices)}",), - interleaved_engine_create_fns=(lambda a: engine,), + interleaved_slices=(f"tpu={len(devices)}",), + interleaved_engine_create_fns=(lambda a: engine,), ) print(f"server_config: {server_config}")