Bugfix to forward autodiff causing different datatype 2 #165784
skpark-rh wants to merge 65 commits into pytorch:main from
Conversation
… tensor from a wrapped number.
… a "is_wrapped_number" as true if the derived derivative is also a wrapped number.
…python side to handle dtype promotions.
…numbers. Then using the correct dtype promotions on the python side.
…rch into bugfix/dtype_foward_agrad
…erations caused dtypes to be different.
I was able to implement the changes requested. Let me know if I need to change something else. Thanks!
TORCH_INTERNAL_ASSERT(tensor.device().is_cpu());
if (tensor.unsafeGetTensorImpl()->is_wrapped_number() ||
    (tensor._is_zerotensor() &&
     tensor.unsafeGetTensorImpl()->is_wrapped_number() &&
This branch of the OR conjunction is pointless since if it's true, the lhs is always true too
Yeah, that's true...
    return Tensor();
}

void update_wrapped_number(Tensor& input, Tensor& output) {
severe parameter ordering blindness here lol
Yeah... for sure. T_T
…c type for failed builds.
I had build failures and some tests that failed with complex dtypes. I just pushed fixes and they should all pass now.
@pytorchmergebot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 2 jobs have failed, first few of them are: trunk / win-vs2022-cpu-py3 / build, trunk / win-vs2022-cuda12.8-py3 / build. Details for Dev Infra team: raised by workflow job.
Two win builds are failing with these errors. [edit]: Found a nonstandard dtype.
@pytorchmergebot merge
Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.
@pytorchmergebot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…calars) Co-authored-by: dilililiwhy <why.wuhuanyu@huawei.com>. Referenced by Ascend/pytorch!26081 (merge main_sync_20251028 into master): "TORCH MAIN SYNC: add update_wrapped_number (bugfix to ForwardADWithScalars)" (description: 2.10.0.dev20251110). Created-by/Commit-by: dilililiwhy; Merged-by: ascend-robot. References: pytorch/pytorch#160513, pytorch/pytorch#165784, pytorch/pytorch#166657.
Fixes #160513
The Problem Summary
The issue boiled down to dtype promotion logic. The code base has two different functions that handle dtype promotion. For purely multi-dimensional tensor operations, the C++ code is triggered, which follows the NumPy promotion rules; that is why NDim tensors are fine in #160513, since NDim dtypes take precedence. The issue comes with Python scalars and 0-dim tensors. When "scalars" are detected, a Python implementation of the promotion logic is triggered (torch/_prims_common/__init__.py:1544). Since this runs in Python, the implementation cannot distinguish a wrapped number from a 0-dim tensor and will simply take the highest dtype, which is the Python double of the wrapped number.
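To make the failure mode concrete, here is a hedged repro sketch based on my reading of #160513 and the description above; the issue's exact script may differ, and the dtype outcomes noted in the comments are my assumption.

    # Hedged repro sketch (my reading of #160513, not the issue's exact script):
    # a float32 0-dim dual tensor times a Python scalar. The scalar's tangent is
    # a float64 ZeroTensor, and before this fix the Python promotion path could
    # not tell it came from a wrapped number, so the tangent could be promoted
    # to float64 while the primal stayed float32.
    import torch
    import torch.autograd.forward_ad as fwAD

    primal = torch.tensor(1.0, dtype=torch.float32)   # 0-dim tensor
    tangent = torch.tensor(1.0, dtype=torch.float32)

    with fwAD.dual_level():
        dual = fwAD.make_dual(primal, tangent)
        out = dual * 3.0                               # Python scalar -> wrapped number
        p, t = fwAD.unpack_dual(out)
        print(p.dtype, t.dtype)                        # after the fix: torch.float32 torch.float32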
The Fix
The Python implementation of dtype promotion had to know where the scalar came from; once the scalar's origin can be distinguished, the appropriate dtype can be set. The first approach was to try to expose the is_wrapped_number method, but this came with a big issue: during forward AD, the derivatives of those scalars turn out to be ZeroTensors. ZeroTensor internally uses a hack that initializes a meta-dtype tensor to skip expensive dispatch operations, but the copy would not grab everything, in particular the is_wrapped_number_ property. I thought about modifying the copy, but that seemed to go against the spirit of what the copy is intended for; besides, the tests for is_wrapped_number_ require dim > 0, and a scalar ZeroTensor is a meta-dtype tensor, which complicates things.

So I chose the route of creating a new property called was_wrapped_number and exposed it through the Python tensor API. I had to modify the autograd code generation to set was_wrapped_number in the mul, add, and div operations in VariableType.cpp. Once this property was set, the dtype promotion logic could be updated to consider both wrapped numbers and 0-dim numbers. Once that hierarchy was taken care of (see the sketch at the end of this description), the buggy behavior was fixed.

I wrote a new ops testing module, TestForwardADWithScalars, since this bug was unique and required a new testing paradigm. It only tests multiply, add, and divide; I chose these because all operations boil down to these three.

[edit]: Just used the efficientzerotensor meta and converted that to a Python number. Since the wrapped number is converted back to a Python number, dtype promotion is preserved. This is achieved by setting the forward-grad zero tensor of a wrapped number with a wrapped-number flag, since the tangent of a wrapped number should still be a wrapped number. That specific ZeroTensor is then sent through as a meta type in BinaryOps.cpp to get the appropriate dtype for the resulting arithmetic.

@ezyang @OihanJoyot
cc @EikanWang @jgong5 @wenzhe-nrv @sanchitintel
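For reference, a small sketch (not code from this PR) of the promotion hierarchy the Python path has to respect, shown through torch.result_type: Python scalars sit at the bottom, 0-dim tensors in the middle, and n-dim tensors take precedence within a dtype category. The tangent of a wrapped number is a real 0-dim ZeroTensor, so without the wrapped-number flag it sat one rung too high on this ladder.

    # Sketch (not from this PR): the promotion lattice exposed by torch.result_type,
    # i.e. Python scalar < 0-dim tensor < n-dim tensor within a dtype category.
    import torch

    nd = torch.tensor([1.0, 2.0], dtype=torch.float32)  # n-dim, float32
    zd = torch.tensor(3.0, dtype=torch.float64)          # 0-dim, float64

    print(torch.result_type(nd, 2.0))   # torch.float32 -- the Python scalar loses
    print(torch.result_type(nd, zd))    # torch.float32 -- n-dim beats 0-dim in the same category
    print(torch.result_type(zd, 2.0))   # torch.float64 -- a 0-dim tensor still beats a bare scalar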