[ROCm] new implementation of upsample_bilinear2d_backward #164572

Closed
glen-amd wants to merge 5 commits into pytorch:main from glen-amd:fix_to_up_and_grid_sample_backward

Conversation

@glen-amd
Contributor

@glen-amd glen-amd commented Oct 3, 2025

Changed the implementation from an output-based approach to an input-based one to remove atomicAdd operations, and it appears to deliver at least a 20× speedup.

The changes are from Yu-Yun (YuYun.Chang@amd.com).

Summary: Refactor of the implementation of the upsample_bilinear2d_backward operation on MI300X/MI325X

  • The original "scatter-add" approach
    • Each thread, representing an output pixel, scattered gradient contributions to four input pixels, using costly atomic operations on MI300X/MI325X GPUs.
  • The new "gather-sum" approach
    • Each thread is responsible for a single input pixel and gathers all relevant gradient contributions from a small, calculated region of the output tensor (done by the compute_output_range device function).
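
For orientation, here is a minimal sketch of what such an output-range helper could look like, assuming align_corners=false and the usual src = (dst + 0.5) * scale - 0.5 source-index mapping with scale = input_size / output_size. The name mirrors the PR's compute_output_range, but the signature and boundary handling shown here are illustrative, not the merged code:

```cuda
// Sketch (not the exact PR code): for one input index, compute the inclusive
// range of output indices whose bilinear support in the forward pass touches it.
// Assumes align_corners == false and scale = input_size / output_size, i.e. the
// forward mapping src = (dst + 0.5f) * scale - 0.5f.
__device__ __forceinline__ void compute_output_range(
    int input_pos, float scale, int output_size,
    int& out_begin, int& out_end) {
  // input_pos receives a contribution whenever src falls in (input_pos - 1, input_pos + 1).
  const float lo = (input_pos - 1 + 0.5f) / scale - 0.5f;
  const float hi = (input_pos + 1 + 0.5f) / scale - 0.5f;
  out_begin = max(0, (int)ceilf(lo));
  out_end = min(output_size - 1, (int)floorf(hi));
}
```

Because this is just the mathematical inverse of the forward index mapping, each thread only has to scan a window of roughly 2 / scale + 1 output positions per dimension instead of the whole output tensor.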

Breakdown of the code changes

  • Inversion of the parallelization strategy of the kernel function upsample_bilinear2d_backward_out_frame
    • Originally, the main kernel loop was parallelized over the number of elements in the output gradient tensor (const size_t o_numel = nc * width2 * height2;).
      • Each thread processed one output pixel.
    • The new loop is parallelized over the number of elements in the input gradient tensor (const size_t i_numel = nc * height1 * width1;).
      • Each thread is responsible for calculating the final gradient for a single input pixel.
    • The kernel launch changes accordingly in the function upsample_bilinear2d_backward_out_cuda_template.
  • Added a device function for calculating the range of output pixels that could possibly have used a given input pixel (input_pos) during the forward-pass interpolation (a possible formulation is sketched above)
    • This is essentially the mathematical inverse of the forward pass.
    • This function tries to prune a thread's search space so that it only needs to inspect a small, local window of the output tensor.
  • Gradient calculation approach switching from "scatter-add" to "gather-sum"
    • Scatter-add
      • For each output pixel, the thread calculated 4 gradient contributions and used fastAtomicAdd 4 times to add these values to 4 different (and potentially highly contended) memory locations in the input gradient tensor.
    • Gather-sum
      • A thread responsible for one input pixel calls compute_output_range to determine the small rectangular region of output pixels that influence the input's final gradient value.
      • The thread iterates through this region, and for each output pixel in the region, it re-calculates the interpolation weights to determine the exact contribution to its specific input pixel.
      • All these contributions are accumulated into a private, per-thread register variable (accscalar_t grad_sum = 0;).
        • Without any global memory access, this accumulation is extremely fast.
      • When the loops are done, the thread performs a single, direct write (non-atomic) of the final summed gradient to its designated location in global memory (idata[index] = static_cast<scalar_t>(grad_sum);).
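
Putting these pieces together, a simplified, single-channel sketch of the gather-sum kernel is shown below. It reuses the compute_output_range helper sketched earlier; batch/channel handling and the exact source-index clamping of the real upsample_bilinear2d_backward_out_frame kernel are simplified, so treat this as an illustration of the structure rather than the merged code:

```cuda
// Illustrative gather-sum kernel: one thread per input-gradient element,
// accumulation in a register, and a single non-atomic write at the end.
template <typename scalar_t, typename accscalar_t>
__global__ void upsample_bilinear2d_backward_gather(
    const scalar_t* __restrict__ odata,  // output gradient, height2 x width2
    scalar_t* __restrict__ idata,        // input gradient,  height1 x width1
    int height1, int width1, int height2, int width2,
    float rheight, float rwidth) {       // scales: input_size / output_size
  const int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index >= height1 * width1) return;
  const int h1 = index / width1;
  const int w1 = index % width1;

  // Small window of output pixels that can contribute to input pixel (h1, w1).
  int h2_begin, h2_end, w2_begin, w2_end;
  compute_output_range(h1, rheight, height2, h2_begin, h2_end);
  compute_output_range(w1, rwidth, width2, w2_begin, w2_end);

  accscalar_t grad_sum = 0;  // per-thread register accumulation, no global traffic
  for (int h2 = h2_begin; h2 <= h2_end; ++h2) {
    // Re-derive the forward interpolation weights for this output row.
    const float h1r = fmaxf(0.f, (h2 + 0.5f) * rheight - 0.5f);
    const int h1p = (int)h1r;
    const float h1lambda = h1r - h1p;
    const float wh = (h1p == h1)     ? (1.f - h1lambda)
                   : (h1p + 1 == h1) ? h1lambda
                                     : 0.f;  // h1 outside this row's support
    for (int w2 = w2_begin; w2 <= w2_end; ++w2) {
      const float w1r = fmaxf(0.f, (w2 + 0.5f) * rwidth - 0.5f);
      const int w1p = (int)w1r;
      const float w1lambda = w1r - w1p;
      const float ww = (w1p == w1)     ? (1.f - w1lambda)
                     : (w1p + 1 == w1) ? w1lambda
                                       : 0.f;
      grad_sum += wh * ww * static_cast<accscalar_t>(odata[h2 * width2 + w2]);
    }
  }
  // Single, coalesced, non-atomic write of the final summed gradient.
  idata[index] = static_cast<scalar_t>(grad_sum);
}
```

A launch with one thread per input element (e.g. 256-thread blocks covering height1 * width1 elements) then replaces the per-output-element launch in upsample_bilinear2d_backward_out_cuda_template, which is exactly the inversion of the parallelization strategy described above.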

Why performance gets boosted

  • Analysis of the root cause of the performance drop
  • First and foremost, elimination of the contention caused by atomic operations
    • Many parallel threads frequently called atomicAdd, attempting to update the exact same memory location in the input gradient tensor at the same time.
      • The GPU's memory controller has to serialize these operations, effectively nullifying the benefit of parallel execution at those contention points.
    • The chiplet-based CDNA 3 architecture of MI300X/MI325X amplified the issue.
      • When contending threads reside on different XCDs, resolving the atomic operation requires high-latency coherence traffic across the Infinity Fabric interconnect.
    • The implementation change eliminates the hardware-level serialization and cross-chiplet coherence traffic caused by the many atomicAdd operations.
  • Improved memory access pattern and locality
    • Write coalescing
      • The plain writes of the accumulated sum (idata[index] = static_cast<scalar_t>(grad_sum);) can be perfectly coalesced by the GPU.
    • Read locality
      • Even though there are many (potentially repeated) reads from the output tensor (static_cast<accscalar_t>(odata[output_idx])), these are highly cache-friendly, meaning the data for one thread is likely to be in the L1 or L2 cache already due to an access from a neighboring thread.
  • Trade-off: computation for memory synchronization
    • The recalculation of interpolation weights maps well onto modern, high-compute-throughput GPUs like the MI300X/MI325X.
    • Removal of atomic operations avoids expensive memory synchronization.

Optimizations of grid_sampler_2d_backward will be addressed in a separate PR.
Doc for reference: (internal only) https://amd.atlassian.net/wiki/spaces/~glencao2/pages/1162750701/PyTorch__grid_sampler_2d_backward

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

@pytorch-bot

pytorch-bot bot commented Oct 3, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164572

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c9eb5ce with merge base 60ac039:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: cuda release notes category label Oct 3, 2025
@linux-foundation-easycla

linux-foundation-easycla bot commented Oct 3, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

@pruthvistony pruthvistony added rocm This tag is for PRs from ROCm team ciflow/rocm Trigger "default" config CI on ROCm ciflow/inductor-rocm Trigger "inductor" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/rocm-mi355 Trigger "default" config CI on ROCm MI355 runners and removed release notes: cuda release notes category labels Oct 3, 2025
@pytorch-bot

pytorch-bot bot commented Oct 3, 2025

Unknown label ciflow/rocm-mi355.
Currently recognized labels are

  • ciflow/b200
  • ciflow/b200-symm-mem
  • ciflow/binaries
  • ciflow/binaries_libtorch
  • ciflow/binaries_wheel
  • ciflow/h100
  • ciflow/h100-cutlass-backend
  • ciflow/h100-distributed
  • ciflow/h100-symm-mem
  • ciflow/inductor
  • ciflow/inductor-cu126
  • ciflow/inductor-micro-benchmark
  • ciflow/inductor-micro-benchmark-cpu-x86
  • ciflow/inductor-perf-compare
  • ciflow/inductor-perf-test-nightly-rocm
  • ciflow/inductor-perf-test-nightly-x86-zen
  • ciflow/inductor-periodic
  • ciflow/inductor-rocm
  • ciflow/linux-aarch64
  • ciflow/mps
  • ciflow/nightly
  • ciflow/op-benchmark
  • ciflow/periodic
  • ciflow/periodic-rocm-mi300
  • ciflow/pull
  • ciflow/quantization-periodic
  • ciflow/riscv64
  • ciflow/rocm
  • ciflow/rocm-mi300
  • ciflow/s390
  • ciflow/slow
  • ciflow/torchbench
  • ciflow/triton_binaries
  • ciflow/trunk
  • ciflow/unstable
  • ciflow/vllm
  • ciflow/win-arm64
  • ciflow/xpu

@pytorch-bot pytorch-bot bot added release notes: cuda release notes category and removed ciflow/rocm Trigger "default" config CI on ROCm ciflow/inductor-rocm Trigger "inductor" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/rocm-mi355 Trigger "default" config CI on ROCm MI355 runners labels Oct 3, 2025
@glen-amd
Contributor Author

glen-amd commented Oct 8, 2025

@jeffdaily / @jerrymannil / @amd-hhashemi - please review. Thanks.

@jeffdaily jeffdaily changed the title Changed the implementation from an output-based approach to an input-… [ROCm] new implementation of upsample_bilinear2d_backward Oct 10, 2025
@pytorch-bot pytorch-bot bot added ciflow/rocm Trigger "default" config CI on ROCm module: rocm AMD GPU support for Pytorch labels Oct 10, 2025
@jeffdaily jeffdaily added release notes: rocm mandatory label ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 and removed release notes: cuda release notes category labels Oct 10, 2025
@pytorch-bot pytorch-bot bot removed ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Oct 10, 2025
@glen-amd
Contributor Author

@jeffdaily - can you please review and add CI tags? Thanks.

@jeffdaily jeffdaily added ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 keep-going Don't stop on first failure, keep running tests until the end labels Oct 13, 2025
@jeffdaily
Collaborator

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 24, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

Contributor

@amd-hhashemi amd-hhashemi left a comment


Can't this zeroing be removed now that you're not using atomics?

@jeffdaily
Collaborator

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased fix_to_up_and_grid_sample_backward onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout fix_to_up_and_grid_sample_backward && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the fix_to_up_and_grid_sample_backward branch from 3503eac to c9eb5ce on October 24, 2025 at 21:55
@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Oct 24, 2025
@glen-amd
Contributor Author

Can't this zeroing be removed now that you're not using atomics?

Good call.
To avoid pushing more changes here, I will address this in the grid_sampler optimization PR (#165337).

@jeffdaily
Collaborator

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 24, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

rocm-repo-management-api bot pushed a commit to ROCm/pytorch that referenced this pull request Oct 28, 2025
rocm-repo-management-api bot pushed a commit to ROCm/pytorch that referenced this pull request Oct 28, 2025
rocm-repo-management-api bot pushed a commit to ROCm/pytorch that referenced this pull request Oct 28, 2025
jeffdaily pushed a commit to ROCm/pytorch that referenced this pull request Nov 17, 2025

Labels

  • ci-no-td (Do not run TD on this PR)
  • ciflow/trunk (Trigger trunk jobs on your pull request)
  • keep-going (Don't stop on first failure, keep running tests until the end)
  • Merged
  • module: rocm (AMD GPU support for Pytorch)
  • open source
  • release notes: rocm (mandatory label)
  • Reverted
  • rocm (This tag is for PRs from ROCm team)
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
