diff --git a/jetstream_pt/ray_engine.py b/jetstream_pt/ray_engine.py index 13d11edc..d56f4ead 100644 --- a/jetstream_pt/ray_engine.py +++ b/jetstream_pt/ray_engine.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Any, Iterable, Optional, Union +from typing import Any, Iterable, Optional, Union, Tuple, List import numpy as np import ray @@ -180,7 +180,9 @@ def create_pytorch_ray_engine( decode_pod_slice_name: str = None, enable_jax_profiler: bool = False, jax_profiler_port: int = 9999, -) -> Any: +) -> Union[ + PyTorchRayEngine, Tuple[List[PyTorchRayEngine], List[PyTorchRayEngine]] +]: # Return tuple as response: issues/107 supported_models = ["llama-2", "llama-3", "gemma"] @@ -254,4 +256,4 @@ def create_pytorch_ray_engine( is_disaggregated=is_disaggregated, pod_slice_name=decode_pod_slice_name, ) - return (prefill_engine, decode_engine) + return ([prefill_engine], [decode_engine]) diff --git a/run_interactive_disaggregated.py b/run_interactive_disaggregated.py index b086d365..6f908266 100644 --- a/run_interactive_disaggregated.py +++ b/run_interactive_disaggregated.py @@ -94,25 +94,27 @@ def create_disaggregated_engines(): os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" start = time.perf_counter() - prefill_engine, decode_engine = ray_engine.create_pytorch_ray_engine( - model_name=_MODEL_NAME.value, - tokenizer_path=_TOKENIZER_PATH.value, - ckpt_path=_CKPT_PATH.value, - bf16_enable=True, - param_size=_SIZE.value, - context_length=_CONTEXT_LENGTH.value, - batch_size=_BATCH_SIZE.value, - quantize_weights=_QUANTIZE_WEIGHTS.value, - quantize_kv=_QUANTIZE_KV_CACHE.value, - max_cache_length=_MAX_CACHE_LENGTH.value, - sharding_config=_SHARDING_CONFIG.value, - is_disaggregated=_IS_DISAGGREGATED.value, - num_hosts=_NUM_HOSTS.value, - decode_pod_slice_name=_DECODE_POD_SLICE_NAME.value, + prefill_engine_list, decode_engine_list = ( + ray_engine.create_pytorch_ray_engine( + model_name=_MODEL_NAME.value, + tokenizer_path=_TOKENIZER_PATH.value, + 
ckpt_path=_CKPT_PATH.value, + bf16_enable=True, + param_size=_SIZE.value, + context_length=_CONTEXT_LENGTH.value, + batch_size=_BATCH_SIZE.value, + quantize_weights=_QUANTIZE_WEIGHTS.value, + quantize_kv=_QUANTIZE_KV_CACHE.value, + max_cache_length=_MAX_CACHE_LENGTH.value, + sharding_config=_SHARDING_CONFIG.value, + is_disaggregated=_IS_DISAGGREGATED.value, + num_hosts=_NUM_HOSTS.value, + decode_pod_slice_name=_DECODE_POD_SLICE_NAME.value, + ) ) print("Initialize engine", time.perf_counter() - start) - return (prefill_engine, decode_engine) + return (prefill_engine_list[0], decode_engine_list[0]) # pylint: disable-next=all diff --git a/run_server_with_ray.py b/run_server_with_ray.py index 325bc108..75c41164 100644 --- a/run_server_with_ray.py +++ b/run_server_with_ray.py @@ -37,6 +37,14 @@ flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler") flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server") +flags.DEFINE_bool( + "is_disaggregated", False, "Disaggregated serving if it's True" +) + +flags.DEFINE_integer("num_hosts", 4, "Number of TPU host", required=False) + +flags.DEFINE_string("decode_pod_slice_name", "", "Decode pod slice name") + def create_engine(): """create a pytorch engine""" @@ -64,6 +72,37 @@ def create_engine(): return engine +def create_disaggregated_engine(): + """create disaggregated prefill and decode pytorch engines""" + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + + start = time.perf_counter() + prefill_engine_list, decode_engine_list = ( + ray_engine.create_pytorch_ray_engine( + model_name=FLAGS.model_name, + tokenizer_path=FLAGS.tokenizer_path, + ckpt_path=FLAGS.checkpoint_path, + bf16_enable=FLAGS.bf16_enable, + param_size=FLAGS.size, + context_length=FLAGS.context_length, + batch_size=FLAGS.batch_size, + quantize_weights=FLAGS.quantize_weights, + quantize_kv=FLAGS.quantize_kv_cache, + max_cache_length=FLAGS.max_cache_length, + sharding_config=FLAGS.sharding_config, + 
enable_jax_profiler=FLAGS.enable_jax_profiler, + jax_profiler_port=FLAGS.jax_profiler_port, + is_disaggregated=FLAGS.is_disaggregated, + num_hosts=FLAGS.num_hosts, + decode_pod_slice_name=FLAGS.decode_pod_slice_name, + ) + ) + + print("Initialize engine", time.perf_counter() - start) + return (prefill_engine_list, decode_engine_list) + + # pylint: disable-next=all def main(argv: Sequence[str]): del argv @@ -74,12 +113,24 @@ def main(argv: Sequence[str]): print(f"devices: {devices}") - engine = create_engine() + if FLAGS.is_disaggregated: + prefill_engine_list, decode_engine_list = create_disaggregated_engine() + chips = int(len(devices) / 2) + server_config = ServerConfig( + prefill_slices=(f"tpu={chips}",), + prefill_engine_create_fns=(lambda a: prefill_engine_list[0],), + generate_slices=(f"tpu={chips}",), + generate_engine_create_fns=(lambda a: decode_engine_list[0],), + is_ray_backend=True, + ) + + else: + engine = create_engine() + server_config = ServerConfig( + interleaved_slices=(f"tpu={len(devices)}",), + interleaved_engine_create_fns=(lambda a: engine,), + ) - server_config = ServerConfig( - interleaved_slices=(f"tpu={len(devices)}",), - interleaved_engine_create_fns=(lambda a: engine,), - ) print(f"server_config: {server_config}") jetstream_server = server_lib.run(