|
| 1 | +import functools |
| 2 | +from typing import Tuple, Callable, List, Optional |
| 3 | +import time |
| 4 | +import dataclasses |
| 5 | + |
| 6 | +import numpy as np |
| 7 | + |
| 8 | +import jax |
| 9 | +import jax.numpy as jnp |
| 10 | +from jax.experimental import mesh_utils, shard_map |
| 11 | +from jax.sharding import PositionalSharding |
| 12 | + |
| 13 | + |
| 14 | +from jax.sharding import Mesh |
| 15 | +from jax.sharding import PartitionSpec |
| 16 | +from jax.sharding import NamedSharding |
| 17 | + |
| 18 | +devices = jax.devices() |
| 19 | +P = PartitionSpec |
| 20 | + |
| 21 | +devices = mesh_utils.create_device_mesh((len(devices),)) |
| 22 | +mesh = Mesh(devices, axis_names=("x",)) |
| 23 | +# y = jax.device_put(x, NamedSharding(mesh, P('a', 'b'))) |
| 24 | + |
| 25 | +L = 1 << 15 |
| 26 | + |
| 27 | + |
| 28 | +@dataclasses.dataclass |
| 29 | +class BenchmarkCase: |
| 30 | + """BenchmarkCase.""" |
| 31 | + |
| 32 | + name: str |
| 33 | + function: Callable |
| 34 | + args_shape: List[Tuple] |
| 35 | + args_sharding: List[PartitionSpec] |
| 36 | + profiler_output: Optional[str] = None |
| 37 | + |
| 38 | + |
| 39 | +start_key = jax.random.key(0) |
| 40 | + |
| 41 | + |
| 42 | +def _new_arg(shape, dtype): |
| 43 | + global start_key # pylint: disable=all |
| 44 | + start_key, _ = jax.random.split(start_key) |
| 45 | + with jax.default_device(jax.devices("cpu")[0]): |
| 46 | + if dtype == jnp.int8.dtype: |
| 47 | + return jax.random.randint(start_key, shape, 0, 100, dtype=dtype) |
| 48 | + else: |
| 49 | + return jax.random.normal(start_key, shape, dtype=dtype) + 1 |
| 50 | + |
| 51 | + |
| 52 | +def _new_args(case, dtype): |
| 53 | + args = [] |
| 54 | + for shape, sharding in zip(case.args_shape, case.args_sharding): |
| 55 | + arg = _new_arg(shape, dtype) |
| 56 | + if sharding is not None: |
| 57 | + arg = jax.device_put(arg, NamedSharding(mesh, sharding)) |
| 58 | + args.append(arg) |
| 59 | + return args |
| 60 | + |
| 61 | + |
| 62 | +def _run_case(case, warmup=2, runtimes=5, dtype=jnp.bfloat16.dtype): |
| 63 | + for _ in range(warmup): |
| 64 | + args = _new_args(case, dtype) |
| 65 | + case.function(*args) |
| 66 | + |
| 67 | + stamps = [] |
| 68 | + for i in range(runtimes): |
| 69 | + args = _new_args(case, dtype) |
| 70 | + jax.block_until_ready(args) |
| 71 | + if case.profiler_output is not None and i == (runtimes - 1): |
| 72 | + jax.profiler.start_trace(case.profiler_output) |
| 73 | + start = time.perf_counter() |
| 74 | + jax.block_until_ready(case.function(*args)) |
| 75 | + end = time.perf_counter() |
| 76 | + if case.profiler_output is not None and i == (runtimes - 1): |
| 77 | + jax.profiler.stop_trace() |
| 78 | + stamps.append(end - start) |
| 79 | + return sum(stamps) / runtimes |
| 80 | + |
| 81 | + |
| 82 | +def _llama_ffn(x, w1, w2, w3): |
| 83 | + w1_res = jax.nn.silu((x @ w1).astype(jnp.bfloat16.dtype)) |
| 84 | + w3_res = x @ w3 |
| 85 | + res = (w1_res * w3_res) @ w2 |
| 86 | + return res |
| 87 | + |
| 88 | + |
| 89 | +@jax.jit |
| 90 | +@functools.partial( |
| 91 | + shard_map.shard_map, |
| 92 | + mesh=mesh, |
| 93 | + in_specs=(P(), P(None, "x"), P("x"), P(None, "x")), |
| 94 | + out_specs=(P()), |
| 95 | +) |
| 96 | +def _llama_ffn_shmap(x, w1, w2, w3): |
| 97 | + for _ in range(3): |
| 98 | + x = _llama_ffn(x, w1, w2, w3) |
| 99 | + x = jax.lax.psum(x, "x") |
| 100 | + return x |
| 101 | + |
| 102 | + |
| 103 | +@jax.jit |
| 104 | +def _llama_ffn_spmd(x, w1, w2, w3): |
| 105 | + for _ in range(3): |
| 106 | + x = _llama_ffn(x, w1, w2, w3) |
| 107 | + x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P())) |
| 108 | + return x |
| 109 | + |
| 110 | + |
| 111 | +dim = 4096 |
| 112 | +multiple_of = 256 |
| 113 | +# hidden_dim = 4 * dim |
| 114 | +# hidden_dim = int(2 * hidden_dim / 3) |
| 115 | +# hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) |
| 116 | +hidden_dim = 11008 |
| 117 | +BATCH = 1024 |
| 118 | + |
| 119 | + |
| 120 | +@jax.jit |
| 121 | +@functools.partial( |
| 122 | + shard_map.shard_map, |
| 123 | + mesh=mesh, |
| 124 | + in_specs=(P("x"),), |
| 125 | + out_specs=(P()), |
| 126 | + check_rep=False, |
| 127 | +) |
| 128 | +def _all_gather(x): |
| 129 | + return jax.lax.all_gather(x, "x") |
| 130 | + |
| 131 | + |
| 132 | +@jax.jit |
| 133 | +@functools.partial( |
| 134 | + shard_map.shard_map, mesh=mesh, in_specs=(P("x"),), out_specs=(P()) |
| 135 | +) |
| 136 | +def _all_reduce(x): |
| 137 | + return jax.lax.psum(x, "x") |
| 138 | + |
| 139 | + |
| 140 | +allcases = [ |
| 141 | + BenchmarkCase( |
| 142 | + name="Matmul replicated", |
| 143 | + function=jax.jit(jnp.matmul), |
| 144 | + args_shape=((L, L), (L, L)), |
| 145 | + args_sharding=(P(), P()), # replicated |
| 146 | + ), |
| 147 | + BenchmarkCase( |
| 148 | + name="Matmul sharded colrow", |
| 149 | + function=jax.jit(jnp.matmul), |
| 150 | + args_shape=((L, L), (L, L)), |
| 151 | + args_sharding=(P(None, "x"), P("x")), # replicated |
| 152 | + ), |
| 153 | + BenchmarkCase( |
| 154 | + name="matmul sharded rowcol", |
| 155 | + function=jax.jit(jnp.matmul), |
| 156 | + args_shape=((L, L), (L, L)), |
| 157 | + args_sharding=(P("x"), P("x", None)), # replicated |
| 158 | + ), |
| 159 | + BenchmarkCase( |
| 160 | + name="all_gather", |
| 161 | + function=_all_gather, |
| 162 | + args_shape=((L, L),), |
| 163 | + args_sharding=(P("x"),), # replicated |
| 164 | + ), |
| 165 | + BenchmarkCase( |
| 166 | + name="all_reduce", |
| 167 | + function=_all_reduce, |
| 168 | + args_shape=((L, L),), |
| 169 | + args_sharding=(P("x"),), # replicated |
| 170 | + ), |
| 171 | + BenchmarkCase( |
| 172 | + name="Llama 3xffn shardmap", |
| 173 | + function=_llama_ffn_shmap, |
| 174 | + args_shape=( |
| 175 | + (BATCH, dim), |
| 176 | + (dim, hidden_dim), |
| 177 | + (hidden_dim, dim), |
| 178 | + (dim, hidden_dim), |
| 179 | + ), |
| 180 | + args_sharding=(P(), P(None, "x"), P("x"), P(None, "x")), |
| 181 | + ), |
| 182 | + BenchmarkCase( |
| 183 | + name="Llama 3xffn gspmd", |
| 184 | + function=_llama_ffn_spmd, |
| 185 | + args_shape=( |
| 186 | + (BATCH, dim), |
| 187 | + (dim, hidden_dim), |
| 188 | + (hidden_dim, dim), |
| 189 | + (dim, hidden_dim), |
| 190 | + ), |
| 191 | + args_sharding=(P(), P(None, "x"), P("x"), P(None, "x")), |
| 192 | + ), |
| 193 | +] |
| 194 | + |
| 195 | + |
| 196 | +def _run_call_cases(cases): |
| 197 | + for dtype in (jnp.bfloat16.dtype, jnp.int8.dtype): |
| 198 | + for case in cases: |
| 199 | + avg = _run_case(case, dtype=dtype) |
| 200 | + dtype_size = 2 if dtype == jnp.bfloat16.dtype else 1 |
| 201 | + input_sizes = tuple( |
| 202 | + [ |
| 203 | + f"{np.prod(size) * dtype_size / (1<<20) :.6} MiB" |
| 204 | + for size in case.args_shape |
| 205 | + ] |
| 206 | + ) |
| 207 | + print( |
| 208 | + f"{dtype} \t {case.name}: \t{avg * 1000 :.6} ms \t sizes: {input_sizes}" |
| 209 | + ) |
| 210 | + |
| 211 | + |
| 212 | +def main(): |
| 213 | + print("Number of devices: ", len(devices)) |
| 214 | + _run_call_cases(allcases) |
| 215 | + |
| 216 | + |
| 217 | +if __name__ == "__main__": |
| 218 | + main() |
0 commit comments