
Commit 1e08833

Add a script to measure speed of basic ops (#168)
* Add a script to measure speed of basic ops
* make batch of ffn larger
* lint
1 parent eb360ee commit 1e08833
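
The new file is a standalone script with a main() entry point guarded by __main__ (see the diff below). A minimal sketch of driving it from Python, assuming it is launched from the repository root with a JAX backend that exposes one or more devices (this driver is an illustration, not part of the commit):

import runpy

# Run benchmarks/basic_ops.py as a script; equivalent to
# `python benchmarks/basic_ops.py` from the repository root (assumed layout).
runpy.run_path("benchmarks/basic_ops.py", run_name="__main__")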

File tree

1 file changed: +218 −0 lines


benchmarks/basic_ops.py

Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
import functools
from typing import Tuple, Callable, List, Optional
import time
import dataclasses

import numpy as np

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils, shard_map
from jax.sharding import PositionalSharding


from jax.sharding import Mesh
from jax.sharding import PartitionSpec
from jax.sharding import NamedSharding

devices = jax.devices()
P = PartitionSpec

devices = mesh_utils.create_device_mesh((len(devices),))
mesh = Mesh(devices, axis_names=("x",))
# y = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))

L = 1 << 15


@dataclasses.dataclass
class BenchmarkCase:
  """A benchmark case: a function plus input shapes and shardings."""

  name: str
  function: Callable
  args_shape: List[Tuple]
  args_sharding: List[PartitionSpec]
  profiler_output: Optional[str] = None


start_key = jax.random.key(0)


def _new_arg(shape, dtype):
  global start_key  # pylint: disable=all
  start_key, _ = jax.random.split(start_key)
  with jax.default_device(jax.devices("cpu")[0]):
    if dtype == jnp.int8.dtype:
      return jax.random.randint(start_key, shape, 0, 100, dtype=dtype)
    else:
      return jax.random.normal(start_key, shape, dtype=dtype) + 1


def _new_args(case, dtype):
  args = []
  for shape, sharding in zip(case.args_shape, case.args_sharding):
    arg = _new_arg(shape, dtype)
    if sharding is not None:
      arg = jax.device_put(arg, NamedSharding(mesh, sharding))
    args.append(arg)
  return args


def _run_case(case, warmup=2, runtimes=5, dtype=jnp.bfloat16.dtype):
  # Warm up first so compilation is excluded from the measurement.
  for _ in range(warmup):
    args = _new_args(case, dtype)
    case.function(*args)

  # Average wall-clock time over `runtimes` runs with fresh inputs; inputs
  # are materialized before the timer starts, and the last run is traced if
  # a profiler output path was given.
  stamps = []
  for i in range(runtimes):
    args = _new_args(case, dtype)
    jax.block_until_ready(args)
    if case.profiler_output is not None and i == (runtimes - 1):
      jax.profiler.start_trace(case.profiler_output)
    start = time.perf_counter()
    jax.block_until_ready(case.function(*args))
    end = time.perf_counter()
    if case.profiler_output is not None and i == (runtimes - 1):
      jax.profiler.stop_trace()
    stamps.append(end - start)
  return sum(stamps) / runtimes


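# The next three functions form the Llama-style (SwiGLU) feed-forward
# benchmark: silu(x @ w1) * (x @ w3) @ w2, applied three times so the
# per-layer communication pattern repeats. Two parallelization strategies
# are compared: explicit shard_map with a manual psum, and GSPMD driven by
# a sharding constraint.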
def _llama_ffn(x, w1, w2, w3):
  w1_res = jax.nn.silu((x @ w1).astype(jnp.bfloat16.dtype))
  w3_res = x @ w3
  res = (w1_res * w3_res) @ w2
  return res


@jax.jit
@functools.partial(
    shard_map.shard_map,
    mesh=mesh,
    in_specs=(P(), P(None, "x"), P("x"), P(None, "x")),
    out_specs=(P()),
)
def _llama_ffn_shmap(x, w1, w2, w3):
  # w1/w3 are column-sharded and w2 is row-sharded along "x", so each device
  # produces a partial output that is combined with an explicit psum.
  for _ in range(3):
    x = _llama_ffn(x, w1, w2, w3)
    x = jax.lax.psum(x, "x")
  return x


@jax.jit
def _llama_ffn_spmd(x, w1, w2, w3):
  # Same computation left to GSPMD: the sharding constraint pins the
  # activation to replicated and lets the compiler insert the collectives.
  for _ in range(3):
    x = _llama_ffn(x, w1, w2, w3)
    x = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
  return x


dim = 4096
multiple_of = 256
# hidden_dim = 4 * dim
# hidden_dim = int(2 * hidden_dim / 3)
# hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
hidden_dim = 11008
BATCH = 1024


@jax.jit
@functools.partial(
    shard_map.shard_map,
    mesh=mesh,
    in_specs=(P("x"),),
    out_specs=(P()),
    check_rep=False,
)
def _all_gather(x):
  return jax.lax.all_gather(x, "x")


@jax.jit
@functools.partial(
    shard_map.shard_map, mesh=mesh, in_specs=(P("x"),), out_specs=(P())
)
def _all_reduce(x):
  return jax.lax.psum(x, "x")


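# Benchmark suite: (L, L) x (L, L) matmuls under three input shardings, the
# two collective wrappers defined above, and the two tensor-parallel FFN
# variants.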
allcases = [
    BenchmarkCase(
        name="Matmul replicated",
        function=jax.jit(jnp.matmul),
        args_shape=((L, L), (L, L)),
        args_sharding=(P(), P()),  # both operands replicated
    ),
    BenchmarkCase(
        name="Matmul sharded colrow",
        function=jax.jit(jnp.matmul),
        args_shape=((L, L), (L, L)),
        args_sharding=(P(None, "x"), P("x")),  # contracting dims sharded
    ),
    BenchmarkCase(
        name="matmul sharded rowcol",
        function=jax.jit(jnp.matmul),
        args_shape=((L, L), (L, L)),
        args_sharding=(P("x"), P("x", None)),  # both sharded on dim 0
    ),
    BenchmarkCase(
        name="all_gather",
        function=_all_gather,
        args_shape=((L, L),),
        args_sharding=(P("x"),),  # sharded on dim 0
    ),
    BenchmarkCase(
        name="all_reduce",
        function=_all_reduce,
        args_shape=((L, L),),
        args_sharding=(P("x"),),  # sharded on dim 0
    ),
    BenchmarkCase(
        name="Llama 3xffn shardmap",
        function=_llama_ffn_shmap,
        args_shape=(
            (BATCH, dim),
            (dim, hidden_dim),
            (hidden_dim, dim),
            (dim, hidden_dim),
        ),
        args_sharding=(P(), P(None, "x"), P("x"), P(None, "x")),
    ),
    BenchmarkCase(
        name="Llama 3xffn gspmd",
        function=_llama_ffn_spmd,
        args_shape=(
            (BATCH, dim),
            (dim, hidden_dim),
            (hidden_dim, dim),
            (dim, hidden_dim),
        ),
        args_sharding=(P(), P(None, "x"), P("x"), P(None, "x")),
    ),
]


def _run_call_cases(cases):
  for dtype in (jnp.bfloat16.dtype, jnp.int8.dtype):
    for case in cases:
      avg = _run_case(case, dtype=dtype)
      dtype_size = 2 if dtype == jnp.bfloat16.dtype else 1
      input_sizes = tuple(
          [
              f"{np.prod(size) * dtype_size / (1<<20) :.6} MiB"
              for size in case.args_shape
          ]
      )
      print(
          f"{dtype} \t {case.name}: \t{avg * 1000 :.6} ms \t sizes: {input_sizes}"
      )


def main():
  print("Number of devices: ", len(devices))
  _run_call_cases(allcases)


if __name__ == "__main__":
  main()
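
Not part of the commit, but for reference: every case above leaves profiler_output unset, so no trace is captured by default. A minimal sketch of using the hook _run_case already provides, assuming it runs in the script's namespace (for example in place of the bare main() call at the bottom; the trace directory is a hypothetical choice):

# Hypothetical addition: trace the final timed run of the first case in each
# dtype sweep; _run_case wraps that run in jax.profiler.start_trace/stop_trace.
allcases[0].profiler_output = "/tmp/basic_ops_trace"  # hypothetical path
main()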

0 commit comments
