37
37
# JAX profiler options.
flags.DEFINE_bool(
    name="enable_jax_profiler",
    default=False,
    help="enable jax profiler",
)
flags.DEFINE_integer(
    name="jax_profiler_port",
    default=9999,
    help="port of JAX profiler server",
)

# Disaggregated-serving options; all three are forwarded to
# ray_engine.create_pytorch_ray_engine by create_disaggregated_engine().
flags.DEFINE_bool(
    name="is_disaggregated",
    default=False,
    help="Disaggregated serving if it's True",
)
flags.DEFINE_integer(
    name="num_hosts",
    default=4,
    help="Number of TPU host",
    required=False,
)
flags.DEFINE_string(
    name="decode_pod_slice_name",
    default="",
    help="Decode pod slice name",
)
47
+
40
48
41
49
def create_engine ():
42
50
"""create a pytorch engine"""
@@ -64,6 +72,37 @@ def create_engine():
64
72
return engine
65
73
66
74
75
def create_disaggregated_engine():
  """Create the prefill and decode engine lists for disaggregated serving.

  Sets the JAX default PRNG implementation to "unsafe_rbg" and the
  TF_CPP_MIN_LOG_LEVEL environment variable to "0", then builds the engines
  through ray_engine.create_pytorch_ray_engine using the command-line flags.
  The elapsed construction time is printed to stdout.

  Returns:
    A tuple (prefill_engine_list, decode_engine_list), exactly as unpacked
    from ray_engine.create_pytorch_ray_engine.
  """
  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

  began_at = time.perf_counter()

  # Gather every flag-derived argument first; dict literals evaluate their
  # values in order, so flags are read in the same sequence as a direct call.
  engine_kwargs = {
      "model_name": FLAGS.model_name,
      "tokenizer_path": FLAGS.tokenizer_path,
      "ckpt_path": FLAGS.checkpoint_path,
      "bf16_enable": FLAGS.bf16_enable,
      "param_size": FLAGS.size,
      "context_length": FLAGS.context_length,
      "batch_size": FLAGS.batch_size,
      "quantize_weights": FLAGS.quantize_weights,
      "quantize_kv": FLAGS.quantize_kv_cache,
      "max_cache_length": FLAGS.max_cache_length,
      "sharding_config": FLAGS.sharding_config,
      "enable_jax_profiler": FLAGS.enable_jax_profiler,
      "jax_profiler_port": FLAGS.jax_profiler_port,
      "is_disaggregated": FLAGS.is_disaggregated,
      "num_hosts": FLAGS.num_hosts,
      "decode_pod_slice_name": FLAGS.decode_pod_slice_name,
  }
  # Unpack into exactly two names so an unexpected result shape still fails
  # here rather than in the caller.
  prefill_engines, decode_engines = ray_engine.create_pytorch_ray_engine(
      **engine_kwargs
  )

  elapsed = time.perf_counter() - began_at
  print("Initialize engine", elapsed)
  return (prefill_engines, decode_engines)
104
+
105
+
67
106
# pylint: disable-next=all
68
107
def main (argv : Sequence [str ]):
69
108
del argv
@@ -74,12 +113,24 @@ def main(argv: Sequence[str]):
74
113
75
114
print (f"devices: { devices } " )
76
115
77
- engine = create_engine ()
116
+ if FLAGS .is_disaggregated :
117
+ prefill_engine_list , decode_engine_list = create_disaggregated_engine ()
118
+ chips = int (len (devices ) / 2 )
119
+ server_config = ServerConfig (
120
+ prefill_slices = (f"tpu={ chips } " ,),
121
+ prefill_engine_create_fns = (lambda a : prefill_engine_list [0 ],),
122
+ generate_slices = (f"tpu={ chips } " ,),
123
+ generate_engine_create_fns = (lambda a : decode_engine_list [0 ],),
124
+ is_ray_backend = True ,
125
+ )
126
+
127
+ else :
128
+ engine = create_engine ()
129
+ server_config = ServerConfig (
130
+ interleaved_slices = (f"tpu={ len (devices )} " ,),
131
+ interleaved_engine_create_fns = (lambda a : engine ,),
132
+ )
78
133
79
- server_config = ServerConfig (
80
- interleaved_slices = (f"tpu={ len (devices )} " ,),
81
- interleaved_engine_create_fns = (lambda a : engine ,),
82
- )
83
134
print (f"server_config: { server_config } " )
84
135
85
136
jetstream_server = server_lib .run (
0 commit comments