Skip to content

Commit cd9a01b

Browse files
authored
Merge pull request #643 from prbasyal-amd/711-develop-syncup
Develop branch synced with 711 changes
2 parents faa9b9c + 6632d46 commit cd9a01b

File tree

2 files changed

+31
-35
lines changed

2 files changed

+31
-35
lines changed

.wordlist.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ IIO
4040
IRQs
4141
json
4242
KMD
43+
libdw
4344
MACVTAP
4445
megablocks
4546
microarchitectures

docs/install/3rd-party/jax-install.rst

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ If you prefer to use the ROCm Ubuntu image or already have a ROCm Ubuntu contain
121121

122122
.. code-block:: bash
123123
124-
docker pull rocm/dev-ubuntu-22.04:7.0-complete
124+
docker pull rocm/dev-ubuntu-24.04:7.1-complete
125125
126126
2. Launch the Docker container. After pulling the image, launch a container using this command:
127127

@@ -138,13 +138,16 @@ If you prefer to use the ROCm Ubuntu image or already have a ROCm Ubuntu contain
138138
--security-opt seccomp=unconfined \
139139
-v $(pwd):/jax_dir \
140140
--name rocm_jax \
141-
rocm/dev-ubuntu-22.04:7.0-complete /bin/bash
141+
rocm/dev-ubuntu-24.04:7.1-complete /bin/bash
142142
143143
3. Install the latest version of JAX. Inside the running container, install the required version of JAX with ROCm support using pip:
144144

145145
.. code-block:: bash
146146
147-
pip3 install jax[rocm]
147+
pip3 install --break-system-packages jax==0.7.1
148+
pip3 install --break-system-packages jax-rocm7-pjrt==0.7.1
149+
pip3 install --break-system-packages jax-rocm7-plugin==0.7.1
150+
pip3 install --break-system-packages https://github.com/ROCm/jax/releases/download/rocm-jax-v0.7.1/jaxlib-0.7.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
148151
149152
4. Verify the installed JAX version. Check whether the correct version of JAX and its ROCm plugins are installed.
150153

@@ -156,18 +159,25 @@ If you prefer to use the ROCm Ubuntu image or already have a ROCm Ubuntu contain
156159

157160
.. code-block::
158161
159-
jax==0.4.35
160-
jax-rocm60-pjrt==0.4.35
161-
jax-rocm60-plugin==0.4.35
162-
jaxlib==0.4.35
162+
jax==0.7.1
163+
jax-rocm7-pjrt==0.7.1
164+
jax-rocm7-plugin==0.7.1
165+
jaxlib==0.7.1
163166
164167
5. Explicitly set the ``LLVM_PATH`` environment variable. This helps XLA find ``ld.lld`` in the PATH at runtime.
165168

166169
.. code-block:: bash
167170
168171
export LLVM_PATH=/opt/rocm/llvm
169172
170-
6. Verify the installation of ROCm JAX. See :ref:`jax-verify-installation`.
173+
6. Install ``libdw1`` if needed
174+
175+
.. code-block:: bash
176+
177+
apt update
178+
apt install libdw1
179+
180+
7. Verify the installation of ROCm JAX. See :ref:`jax-verify-installation`.
171181

172182
.. _install-jax-rocm-custom-container:
173183

@@ -206,7 +216,10 @@ Follow these steps if you prefer to install ROCm manually on your host system or
206216

207217
.. code-block:: bash
208218
209-
pip3 install jax[rocm]
219+
pip3 install --break-system-packages jax==0.7.1
220+
pip3 install --break-system-packages jax-rocm7-pjrt==0.7.1
221+
pip3 install --break-system-packages jax-rocm7-plugin==0.7.1
222+
pip3 install --break-system-packages https://github.com/ROCm/jax/releases/download/rocm-jax-v0.7.1/jaxlib-0.7.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
210223
211224
3. Verify the installed JAX version. Check whether the correct version of JAX and its ROCm plugins are installed.
212225

@@ -220,46 +233,28 @@ Follow these steps if you prefer to install ROCm manually on your host system or
220233
221234
export LLVM_PATH=/opt/rocm/llvm
222235
223-
5. Apply the namespace patch:
224-
225-
.. code-block:: bash
226-
227-
patch -p1 \
228-
-d "$(python3 -c \"import sysconfig; print(sysconfig.get_paths()['purelib'])\")" \
229-
< jax_rocm_plugin/third_party/jax/namespace.patch
230-
231-
6. Verify the installation of ROCm JAX.
232-
233-
Run the following commands to verify that ROCm JAX is installed correctly:
236+
5. Install ``libdw1`` if needed
234237

235238
.. code-block:: bash
236239
237-
python3 -c "import jax; print(jax.devices())"
238-
python3 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)"
239-
240-
Expected output:
241-
242-
.. code-block::
240+
apt update
241+
apt install libdw1
243242

244-
[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
245-
246-
.. code-block::
247-
248-
[0 1 2 3 4]
243+
6. Verify the installation of ROCm JAX. See :ref:`jax-verify-installation`.
249244

250245
.. _build-jax-from-source:
251246
.. _build-jax-wheels:
252247

253248
Build JAX from source
254249
--------------------------------------------------------------------------------------
255250

256-
The `<https://github.com/ROCm/rocm-jax>`__ repository contains sources for the ROCm
251+
The `<https://github.com/ROCm/rocm-jax/tree/rocm-jaxlib-v0.7.1>`__ repository contains sources for the ROCm
257252
plugin for JAX as well as Dockerfiles used to build the AMD ``rocm/jax`` images.
258253
For the most up-to-date instructions, refer directly to the instructions in the repository:
259254

260-
- See `Quick build <https://github.com/ROCm/ROCm-jax?tab=readme-ov-file#quickbuild>`__ for concise high-level steps.
255+
- See `Quick build <https://github.com/ROCm/ROCm-jax/tree/rocm-jaxlib-v0.7.1?tab=readme-ov-file#quickbuild>`__ for concise high-level steps.
261256

262-
- See `Building <https://github.com/ROCm/rocm-jax/blob/master/BUILDING.md#building>`__ for more in-depth build instructions and troubleshooting suggestions.
257+
- See `Building <https://github.com/ROCm/rocm-jax/blob/rocm-jaxlib-v0.7.1/BUILDING.md#building>`__ for more in-depth build instructions and troubleshooting suggestions.
263258

264259
.. _jax-verify-installation:
265260

@@ -270,7 +265,7 @@ After launching the container, test whether JAX detects ROCm devices as expected
270265

271266
.. code-block:: bash
272267
273-
python -c "import jax; print(jax.devices())"
268+
python3 -c "import jax; print(jax.devices())"
274269
python3 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)"
275270
276271
If the setup is successful, the output should list all available ROCm devices.

0 commit comments

Comments
 (0)