
I'm studying JAX's memory allocation to make my code faster, and I found that JAX uses GPU memory even though I set it to use only the CPU. My small test script is:

import jax
import jax.numpy as jnp

jax.config.update('jax_platform_name', 'cpu')

x=jnp.zeros(10)

for i in range(10000000000):
    1+1

The for loop just keeps the process running so I can check whether it is using the GPU.

What I found is that it always uses 253MiB of GPU memory:

303028 oh 20 0 29.5g 377844 294168 R 100.0 0.1 0:08.21 python ./test.py

and

[image: list of processes using GPU memory]

Actually, PIDs 299133 and 299522 are also using GPU memory with JAX set to use the CPU. I'm not sure whether this is why my actual code is much slower than my C++ code, but how can I make JAX not use the GPU at all?

  • Indeed, it is the same for me. Installing jax[cpu] (the CPU-only version) fixes the problem, but I don't know if that is what you want... Commented Aug 30, 2022 at 21:44

1 Answer

If you have a GPU-enabled jaxlib installed, JAX will by default preallocate a large share of the GPU memory (90%; newer releases document 75%) when the first JAX operation runs, unless you specify otherwise via the appropriate environment variables. Examples are:

  • XLA_PYTHON_CLIENT_PREALLOCATE=false disables preallocation behavior
  • XLA_PYTHON_CLIENT_MEM_FRACTION=.XX specifies the fraction of memory to preallocate
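As a rough sketch of how these variables might be applied from inside a script rather than from the shell: they need to be in the environment before JAX initializes its backends, so the example sets them before importing jax. The value ".50" is an arbitrary illustration, not a recommendation, and the try/except is only there so the snippet also runs on a machine without JAX installed.

```python
import os

# Set these before importing jax, so they are visible when the backend starts up.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # disable preallocation
# Alternative (illustrative value): cap the preallocated fraction instead.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

try:
    import jax.numpy as jnp
except ImportError:
    # Assumption for this sketch only: tolerate a missing JAX install.
    jnp = None

if jnp is not None:
    x = jnp.zeros(10)  # first JAX operation; the backend is initialized here
    print(x.shape)
```

The same effect can be had by exporting the variables in the shell before launching Python, which avoids any dependence on import order.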

See JAX: GPU Memory Allocation for more information on this.

If you want to avoid GPU memory allocation altogether, one option is to re-install a CPU-only jaxlib:

pip install --force-reinstall "jax[cpu]"
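After reinstalling, one way to check what JAX actually sees is to query the default backend and the device list. This is a minimal sketch; the helper name active_backend is made up for illustration, and it returns None when JAX is not installed.

```python
def active_backend():
    """Return the name of JAX's default backend, or None if JAX is missing."""
    try:
        import jax
    except ImportError:
        return None
    # Print the devices JAX can see, then report the default platform name.
    print(jax.devices())
    return jax.default_backend()


backend = active_backend()
print("default backend:", backend)
```

With a CPU-only install, the reported backend should be "cpu" and the device list should contain only CPU devices.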