
Bugfix to forward autodiff causing different datatype 2 #165784

Closed

skpark-rh wants to merge 65 commits into pytorch:main from skpark-rh:bugfix/dtype_foward_agrad


Conversation

@skpark-rh (Collaborator) commented Oct 17, 2025

Fixes #160513

The Problem Summary

The issue boiled down to dtype promotion logic. The code base has two different functions that handle dtype promotion. For purely multi-dimensional tensor operations, the C++ path is taken, and it follows NumPy's dtype promotion rules; that is why N-dim tensors behave correctly in #160513, since their dtypes take precedence. The problem arises with Python scalars and 0-dim tensors: when "scalars" are detected, a Python implementation of the promotion logic runs instead (torch/_prims_common/__init__.py:1544). Because that code is in Python, it cannot distinguish a wrapped number (a Python scalar wrapped into a tensor) from a genuine 0-dim tensor, so it simply takes the highest dtype, which is the double dtype of the wrapped Python number.
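
For context, here is a minimal sketch of the kind of behavior reported in #160513 (the exact values and trigger conditions are illustrative): multiplying a float32 dual number by a Python float should leave the tangent dtype at float32, but the buggy promotion path could promote it to float64.

```python
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.tensor(2.0, dtype=torch.float32)   # 0-dim tensor
tangent = torch.tensor(1.0, dtype=torch.float32)

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    out = dual * 3.14                              # Python float, i.e. a "wrapped number"
    _, out_tangent = fwAD.unpack_dual(out)

print(out.dtype)          # torch.float32
print(out_tangent.dtype)  # expected torch.float32; the bug could yield torch.float64
```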

The Fix

The Python implementation of dtype promotion has to know where a scalar came from; once a wrapped number can be distinguished from a real 0-dim tensor, the appropriate dtype can be chosen. The first approach was to try to expose the is_wrapped_number method, but it ran into a big issue: during forward AD, the derivatives of those scalars turn out to be ZeroTensors. A ZeroTensor internally uses a shortcut that initializes a meta-dtype tensor in order to skip expensive dispatch operations, and that copy does not carry everything over, in particular the is_wrapped_number_ property. I thought about modifying the copy, but that seemed to go against the spirit of what the copy was intended for; on top of that, the checks around is_wrapped_number_ impose dimension constraints, while a scalar ZeroTensor is a meta-dtype tensor, which complicates things.

So I chose the route of creating a new property, was_wrapped_number, and exposing it through the Python tensor API. I had to modify the autograd code generation so that was_wrapped_number gets set in the mul, add, and div operations in VariableType.cpp. With this property in place, the dtype promotion logic could be updated to treat wrapped numbers and 0-dim tensors differently, and once that hierarchy was respected the buggy behavior was fixed.
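
As a rough illustration of the ordering the fix restores (these are eager-mode semantics that the Python promotion logic is meant to mirror): a Python scalar only participates in category promotion, so it should never widen the dtype of an actual tensor operand, whether 0-dim or N-dim.

```python
import torch

ndim = torch.ones(3, dtype=torch.float32)
zerodim = torch.tensor(2.0, dtype=torch.float32)

# A Python float joins as a "wrapped number": it can bump an integer tensor
# into the floating-point category, but it must not widen float32 to float64.
print(torch.result_type(ndim, 3.14))     # torch.float32
print(torch.result_type(zerodim, 3.14))  # torch.float32
print(torch.result_type(torch.ones(3, dtype=torch.int64), 3.14))  # torch.float32
```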

I wrote a new ops testing module, TestForwardADWithScalars, since this bug was unusual enough to need a new testing paradigm. It only covers multiply, add, and divide; I chose these because the affected operations all boil down to those three. A sketch of the kind of check involved follows below.
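
A hedged sketch of that kind of check (the names and structure here are illustrative, not the actual TestForwardADWithScalars code): for each of mul, add, and div, the tangent coming out of a scalar op should keep the dual's original dtype.

```python
import torch
import torch.autograd.forward_ad as fwAD

def assert_tangent_dtype_preserved(op, dtype=torch.float32):
    primal = torch.tensor(2.0, dtype=dtype)
    tangent = torch.tensor(1.0, dtype=dtype)
    with fwAD.dual_level():
        dual = fwAD.make_dual(primal, tangent)
        out = op(dual, 3.0)                 # second operand is a Python scalar
        _, out_tangent = fwAD.unpack_dual(out)
    assert out_tangent.dtype == dtype, (op, out_tangent.dtype, dtype)

for op in (torch.mul, torch.add, torch.div):
    assert_tangent_dtype_preserved(op)
```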

[edit]: The final approach just uses the efficientzerotensor meta tensor and converts it back to a Python number. Since the wrapped number is converted back to a Python number, dtype promotion is preserved. This works by flagging the forward-grad zero tensor of a wrapped number as itself a wrapped number, since the tangent of a wrapped number should still be a wrapped number. That specific ZeroTensor is then sent through as a meta-dtype tensor in BinaryOps.cpp to get the appropriate dtype for the resulting arithmetic.
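
For readers unfamiliar with ZeroTensors, a quick illustration of the concept referenced above, using private APIs whose exact signatures may vary across PyTorch versions:

```python
import torch

# A ZeroTensor is a lazily-materialized all-zeros tensor that autograd uses,
# e.g. as the forward-mode tangent of an operand that has no tangent of its
# own (such as a plain Python scalar).
z = torch._efficientzerotensor((), dtype=torch.float32)
print(z._is_zerotensor())  # True
print(z.dtype)             # torch.float32
print(z.dim())             # 0
```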

@ezyang @OihanJoyot

cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel

… a "is_wrapped_number" as true if the derived derivated is also a wrapped number.
…numbers. Then using the correct dtype promotions on the python side.
… a "is_wrapped_number" as true if the derived derivated is also a wrapped number.
…numbers. Then using the correct dtype promotions on the python side.
… a "is_wrapped_number" as true if the derived derivated is also a wrapped number.
@skpark-rh requested a review from ezyang on November 3, 2025 20:58
@skpark-rh (Collaborator, Author) commented Nov 3, 2025

I was able to implement the changes requested. Let me know if I need to change something else. Thanks!

TORCH_INTERNAL_ASSERT(tensor.device().is_cpu());
if (tensor.unsafeGetTensorImpl()->is_wrapped_number() ||
(tensor._is_zerotensor() &&
tensor.unsafeGetTensorImpl()->is_wrapped_number() &&
Contributor:

This branch of the OR is pointless, since if it's true, the left-hand side is always true too.

Collaborator Author:

Yeah, that's true...

return Tensor();
}

void update_wrapped_number(Tensor& input, Tensor& output) {
Contributor:

severe parameter ordering blindness here lol

Collaborator Author:

Yeah... for sure. T_T

@ezyang (Contributor) left a comment:

looks good if ci passes

@skpark-rh (Collaborator, Author):

I had build failures and some tests that failed with complex dtypes. I just pushed fixes, and they should all pass now.

@skpark-rh (Collaborator, Author):

@pytorchmergebot merge

@pytorch-bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Nov 4, 2025
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

Merge failed

Reason: 2 jobs have failed, first few of them are: trunk / win-vs2022-cpu-py3 / build, trunk / win-vs2022-cuda12.8-py3 / build

Details for Dev Infra team (raised by workflow job)

@skpark-rh (Collaborator, Author) commented Nov 5, 2025

Two Windows builds are failing with these errors: rm: cannot remove './build/win_tmp/bin': Device or resource busy. This doesn't seem to be an issue with the PR, though.

[edit]: Found a nonstandard dtype, u_int64_t, that was causing compile issues on the Windows builds.

@pytorch-bot removed the ciflow/trunk (Trigger trunk jobs on your pull request) label on Nov 5, 2025
@skpark-rh (Collaborator, Author):

@pytorchmergebot merge

@pytorch-bot commented Nov 5, 2025

The pull workflow has not been scheduled for this PR yet. This could be because the author doesn't have permission to run those workflows, or because skip-checks keywords were added to the PR/commits; aborting the merge. Please get/give approval for the workflows and/or remove skip ci decorators before the next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.

@skpark-rh (Collaborator, Author):

@pytorchmergebot merge

@pytorch-bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label on Nov 5, 2025
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@skpark-rh deleted the bugfix/dtype_foward_agrad branch on November 6, 2025 13:38
drizzlezyk pushed a commit to Ascend/pytorch that referenced this pull request Nov 17, 2025
…calars)

Co-authored-by: dilililiwhy <why.wuhuanyu@huawei.com>



# message auto-generated for no-merge-commit merge:
!26081 merge main_sync_20251028 into master

TORCH MAIN SYNC : add update_wrapped_number (bugfix to ForwardADWithScalars)

Created-by: dilililiwhy
Commit-by: dilililiwhy
Merged-by: ascend-robot
**What does this PR do / why do we need it**:
2.10.0.dev20251110


**Special notes for your reviewers**:
pytorch/pytorch#160513
pytorch/pytorch#165784

pytorch/pytorch#166657


See merge request: Ascend/pytorch!26081

Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: forward ad, oncall: jit (Add this issue/PR to JIT oncall triage queue), open source, release notes: autograd (release notes category), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development

Successfully merging this pull request may close these issues.

Forward autodiff : Multiplying by python float changes the dual dtype in some situations

6 participants