I'm studying how JAX allocates memory in order to make my code faster, and I noticed that JAX uses GPU memory even though I set it to use only the CPU. My small test script is:
```python
import jax
import jax.numpy as jnp

jax.config.update('jax_platform_name', 'cpu')
x = jnp.zeros(10)

# Busy loop that keeps the process alive so its GPU usage can be inspected
for i in range(10000000000):
    1 + 1
```
The for loop only keeps the program running long enough to check whether it is using the GPU.
While it runs, the process always uses 253MiB of GPU memory:
```
303028 oh 20 0 29.5g 377844 294168 R 100.0 0.1 0:08.21 python ./test.py
```
and in the GPU process list (screenshot not preserved).
In fact, PIDs 299133 and 299522 are also using GPU memory, even though JAX is set to use the CPU. I'm not sure whether this is why my actual code is much slower than my C++ code, but how can I configure JAX not to use the GPU at all?
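A minimal sketch of one option, assuming a JAX version that reads the `JAX_PLATFORMS` environment variable (older releases used `JAX_PLATFORM_NAME` instead): restrict JAX to the CPU backend before `jax` is imported. Presumably `jax_platform_name` only selects the default device, while the GPU backend can still be initialized (creating a CUDA context that holds some GPU memory); `JAX_PLATFORMS=cpu` limits which backends JAX initializes in the first place.

```python
import os

# Must be set before jax is imported. Assumption: this JAX version honors
# JAX_PLATFORMS and therefore initializes only the listed backend(s).
os.environ["JAX_PLATFORMS"] = "cpu"

import jax
import jax.numpy as jnp

x = jnp.zeros(10)

# Report which backend and devices JAX actually set up
print(jax.default_backend())  # "cpu"
print(jax.devices())
```

Setting the variable from the shell (e.g. `JAX_PLATFORMS=cpu python ./test.py`) should have the same effect without modifying the script.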

Installing jax[cpu] (the CPU-only version) fixes the problem, but I don't know if it is what you want...
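To confirm that the CPU-only install is the one actually being picked up, a quick check like the following sketch can help: with the CPU-only build there is no GPU backend to initialize, so every device JAX reports should be a CPU device.

```python
import jax

# List every device JAX can see; with a CPU-only install these should all be CPU devices
devices = jax.devices()
print(devices)

# Collect any non-CPU devices (expected to be empty with the CPU-only build)
non_cpu = [d for d in devices if d.platform != "cpu"]
print("non-CPU devices:", non_cpu)
```

If `non_cpu` is not empty, a GPU-enabled jaxlib is still installed in that environment.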