
[ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half#167233

Closed
jerrymannil wants to merge 2 commits into pytorch:main from jerrymannil:patch-1

Conversation

@jerrymannil
Collaborator

@jerrymannil jerrymannil commented Nov 6, 2025

  • c10::fetch_and_cast and c10::cast_and_store produce branchy code, since they have to support all datatypes at runtime
  • So, we add special handling for binary elementwise broadcast with mixed dtypes of float/bfloat16/half (see the sketch after this list)
  • This improves performance for these mixed-dtype cases (benchmark results below)
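
To make the first bullet concrete, here is a minimal, self-contained C++ sketch of the idea. It is not the actual c10 implementation, and the names fetch_and_cast_generic and fetch_and_cast_specialized are made up for illustration: the generic path must branch on the runtime dtype for every element it loads, while a kernel instantiated for a fixed float/bfloat16/half combination compiles the cast down to a plain conversion.

#include <cstdio>

// Toy stand-in for c10::ScalarType; the real enum covers every PyTorch dtype.
enum class ScalarType { Float, Half, BFloat16 /* , ... many more ... */ };

// Generic path (simplified): the source dtype is only known at runtime, so
// every element load goes through a switch. With all dtypes supported, the
// inner loop ends up full of branches.
template <typename dest_t>
dest_t fetch_and_cast_generic(ScalarType src_type, const void* ptr) {
  switch (src_type) {
    case ScalarType::Float:
      return static_cast<dest_t>(*static_cast<const float*>(ptr));
    // Half, BFloat16, and every other dtype each add another case here.
    default:
      return dest_t{};
  }
}

// Specialized path (sketch): when the operand dtypes are fixed at kernel
// instantiation time (e.g. a bfloat16 input feeding a float op), the load
// plus cast is a straight conversion with no per-element dispatch.
template <typename dest_t, typename src_t>
dest_t fetch_and_cast_specialized(const void* ptr) {
  return static_cast<dest_t>(*static_cast<const src_t*>(ptr));
}

int main() {
  float x = 1.5f;
  std::printf("%f\n", fetch_and_cast_generic<double>(ScalarType::Float, &x));
  std::printf("%f\n", fetch_and_cast_specialized<double, float>(&x));
}

The same contrast plays out inside the GPU elementwise kernel's inner loop, which is why hard-coding the handful of float/bfloat16/half combinations pays off for broadcasted binary ops.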

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

@pytorch-bot

pytorch-bot bot commented Nov 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167233

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 598d9b3 with merge base 73078f3:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: rocm and release notes: cuda labels Nov 6, 2025
@jerrymannil jerrymannil marked this pull request as draft November 6, 2025 18:03
@jerrymannil
Collaborator Author

jerrymannil commented Nov 6, 2025

Reproducer:

import time
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--events", action="store_true", help="Use CUDA events")
parser.add_argument("--check", action="store_true", help="Enable correctness check")
args = parser.parse_args()

shapes = [[(34816, 1), (34816, 3840)]]


for shape in shapes:
    a = torch.randn(shape[0], device='cuda', dtype=torch.float)
    b = torch.randn(shape[1], device='cuda', dtype=torch.bfloat16)
    for i in range(20):  # warmup (optional correctness check on iteration 5)
        if args.check and i == 5:
            a_cpu = a.cpu()
            b_cpu = b.cpu()
            c_cpu = torch.mul(a_cpu, b_cpu)
            c = torch.mul(a, b)
            torch.cuda.synchronize()
            assert torch.equal(c.cpu(), c_cpu)
        _ = torch.mul(a, b)
    torch.cuda.synchronize()

    if args.events:
        start_evt = torch.cuda.Event(enable_timing=True)
        end_evt = torch.cuda.Event(enable_timing=True)
        start_evt.record()
    else:
        start_time = time.perf_counter_ns()

    for _ in range(100):
        c = torch.mul(a, b)

    if args.events:
        end_evt.record()
        torch.cuda.synchronize()
    else:
        torch.cuda.synchronize()
        end_time = time.perf_counter_ns()

    if args.events:
        print(f"Avg time for shape {shape}: {start_evt.elapsed_time(end_evt) / 100 * 1e3:.2f} us")
    else:
        print(f"Avg time for shape {shape}: {(end_time - start_time) / (100 * 1e3):.2f} us")

Results (MI300X):

Before:
Avg time for shape [(34816, 1), (34816, 3840)]: 406.78 us

After (≈1.75× faster):
Avg time for shape [(34816, 1), (34816, 3840)]: 231.73 us

@jeffdaily jeffdaily added the release notes: rocm and ciflow/rocm-mi300 labels and removed the release notes: cuda label Nov 6, 2025
@pytorch-bot pytorch-bot bot removed the ciflow/rocm-mi300 label Nov 6, 2025
@jerrymannil jerrymannil marked this pull request as ready for review November 6, 2025 21:13
@jeffdaily jeffdaily added the ciflow/rocm-mi300 label Nov 6, 2025
@jerrymannil
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Nov 7, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

jerrymannil added a commit to ROCm/pytorch that referenced this pull request Nov 7, 2025
[ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)

cherry-pick of pytorch#167233

Fixes #SWDEV-551924
jerrymannil added a commit to ROCm/pytorch that referenced this pull request Nov 7, 2025
[ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2795)

cherry-pick of pytorch#167233

Fixes #SWDEV-551924
@jerrymannil jerrymannil deleted the patch-1 branch November 7, 2025 03:15
Silv3S pushed a commit to Silv3S/pytorch that referenced this pull request Nov 18, 2025
[ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (pytorch#167233)

* `c10::fetch_and_cast` and `c10::cast_and_store` produce branchy code since they support all datatypes
* So, we do special handling for binary elementwise broadcast with mixed dtypes of float/bfloat16/half
* This improves performance

Pull Request resolved: pytorch#167233
Approved by: https://github.com/jeffdaily
jerrymannil added a commit to ROCm/pytorch that referenced this pull request Nov 19, 2025