Replies: 1 comment
-
The lowering follows the device placement of the inputs. So, for example, one way to get the CPU and GPU lowerings for a function is like this:
import jax
import numpy as np
x = np.random.randn(5, 5).astype('float32')
func = jax.numpy.linalg.eig
# Place the input on CPU and inspect the compiled executable:
x_cpu = jax.device_put(x, jax.devices('cpu')[0])
print(jax.jit(func).lower(x_cpu).compile().as_text())
# Place the input on GPU and do the same:
x_gpu = jax.device_put(x, jax.devices('gpu')[0])
print(jax.jit(func).lower(x_gpu).compile().as_text())
If you inspect the output, you should see that the first compiles to CPU LAPACK calls, and the second compiles to CUDA calls.
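As a side note, a lowering can also be produced from abstract inputs: recent JAX versions let lower() accept jax.ShapeDtypeStruct arguments, so you can inspect the backend-independent StableHLO without creating a concrete array or placing it on a device. A minimal sketch, assuming a recent jax/jaxlib install:
import jax
import jax.numpy as jnp
func = jax.numpy.linalg.eig
# Describe the input by shape and dtype only; no data or device placement needed.
abstract_x = jax.ShapeDtypeStruct((5, 5), jnp.float32)
lowered = jax.jit(func).lower(abstract_x)
# StableHLO text of the lowering, before any backend-specific compilation.
print(lowered.as_text())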
-
Is it possible to compile a lowered trace for specific hardware? E.g., to compile for a specific GPU on a CPU-only machine, to show the generated XLA code with
jax.jit(function).lower(*args).compile(???).as_text()
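For reference, here is what that chain looks like when run against a backend that is locally available; this is a minimal sketch with a hypothetical function, and it does not fill in the ??? placeholder, since the compile step below simply targets the backend that the inputs live on:
import jax
import jax.numpy as jnp
import numpy as np

def function(a):
    # Hypothetical example function, just to have something to jit.
    return (a @ a.T).sum()

args = (np.random.randn(5, 5).astype('float32'),)
lowered = jax.jit(function).lower(*args)   # trace and lower to StableHLO
compiled = lowered.compile()               # XLA-compile for the local default backend
print(compiled.as_text())                  # optimized, backend-specific HLO text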