
[ROCm] Improve perf for elementwise broadcast with mixed dtype#163562

Closed
jerrymannil wants to merge 1 commit into pytorch:main from jerrymannil:patch-1

Conversation

@jerrymannil (Collaborator) commented Sep 22, 2025

* Unroll loops manually to hide memory access latency
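For context, below is a minimal sketch of the manual-unrolling idea, not the actual kernel changed by this PR (the real change is guarded by ifdef USE_ROCM inside PyTorch's elementwise kernels). The kernel name, the unroll factor of 4, and the indexing scheme are illustrative assumptions, and plain float stands in for the mixed float/bfloat16 case:

__global__ void mul_broadcast_unrolled(const float* a,   // shape (rows, 1): one value per row, broadcast across columns
                                       const float* b,   // shape (rows, cols), contiguous
                                       float* out,
                                       long n,            // total number of output elements (rows * cols)
                                       long cols) {
  constexpr int UNROLL = 4;
  long base = ((long)blockIdx.x * blockDim.x + threadIdx.x) * UNROLL;

  if (base + UNROLL <= n) {
    float va[UNROLL], vb[UNROLL];
    // Issue all loads first so their memory latencies overlap.
    #pragma unroll
    for (int i = 0; i < UNROLL; ++i) {
      long idx = base + i;
      va[i] = a[idx / cols];   // broadcast operand
      vb[i] = b[idx];
    }
    // Consume the loaded values only after every load has been issued.
    #pragma unroll
    for (int i = 0; i < UNROLL; ++i) {
      out[base + i] = va[i] * vb[i];
    }
  } else {
    // Tail: handle the remaining elements one at a time.
    for (long idx = base; idx < n; ++idx) {
      out[idx] = a[idx / cols] * b[idx];
    }
  }
}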
@pytorch-bot bot commented Sep 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163562

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Cancelled Job

As of commit e774647 with merge base e558f7a:

NEW FAILURE - The following job has failed:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: rocm (AMD GPU support for Pytorch) and release notes: cuda (release notes category) labels Sep 22, 2025
@jerrymannil (Collaborator, Author)

Reproducer:

import time
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--events", action="store_true", help="Use CUDA events")
parser.add_argument("--check", action="store_true", help="Enable correctness check")
args = parser.parse_args()

shapes = [[(34816, 1), (34816, 3840)]]

for shape in shapes:
    a = torch.randn(shape[0], device='cuda', dtype=torch.float)
    b = torch.randn(shape[1], device='cuda', dtype=torch.bfloat16)
    for i in range(20):
        if args.check and i == 5:
            a_cpu = a.cpu()
            b_cpu = b.cpu()
            c_cpu = torch.mul(a_cpu, b_cpu)
            c = torch.mul(a, b)
            assert torch.equal(c.cpu(), c_cpu)
        _ = torch.mul(a, b)
    torch.cuda.synchronize()

    if args.events:
        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)
        start_evt.record()
    else:
        start_time = time.perf_counter_ns()

    for _ in range(100):
        c = torch.mul(a, b)

    if args.events:
        end_evt.record()
        torch.cuda.synchronize()
    else:
        torch.cuda.synchronize()
        end_time = time.perf_counter_ns()

    if args.events:
        print(f"Avg time for shape {shape}: {start_evt.elapsed_time(end_evt) / 100 * 1e3:.2f} us")
    else:
        print(f"Avg time for shape {shape}: {(end_time - start_time) / (100 * 1e3):.2f} us")
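Saved as, say, repro.py (the filename is arbitrary), the reproducer can be run as python repro.py for wall-clock timing or python repro.py --events to time with CUDA events; adding --check also verifies the result against a CPU reference.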

Results on MI325X

Before:
Avg time for shape [(34816, 1), (34816, 3840)]: 432.10 us

After:
Avg time for shape [(34816, 1), (34816, 3840)]: 381.74 us
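That is roughly a 12% reduction in average time for this shape.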

@jeffdaily jeffdaily added the release notes: rocm, ciflow/rocm (Trigger "default" config CI on ROCm), and ciflow/rocm-mi300 (Trigger "default" config CI on ROCm MI300) labels and removed the release notes: cuda (release notes category) label Sep 22, 2025
@jeffdaily (Collaborator)

@pytorchbot merge -f "change is completely inside ifdef USE_ROCM, ROCm CI is passing"

@jerrymannil (Collaborator, Author)

The single failure is an intermittent issue; I am able to run the test fine in my local setup:

 python test/inductor/test_cuda_repro.py -k "test_repeated_masked_load" --verbose
test_repeated_masked_load (__main__.CudaReproTests.test_repeated_masked_load) ... expected failure

----------------------------------------------------------------------
Ran 1 test in 0.004s

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@jerrymannil jerrymannil deleted the patch-1 branch September 23, 2025 17:52
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
…ch#163562)

* Unroll loops manually to hide memory access latency

Co-author: @amd-hhashemi

Pull Request resolved: pytorch#163562
Approved by: https://github.com/jeffdaily

Labels

ciflow/rocm (Trigger "default" config CI on ROCm), ciflow/rocm-mi300 (Trigger "default" config CI on ROCm MI300), Merged, module: rocm (AMD GPU support for Pytorch), open source, release notes: rocm
