
Conversation

@ZelboK (Contributor) commented Apr 24, 2024

Fixes #121965

This PR adds support for complex numbers in the scatter/gather-related kernels. For brevity, I will only include `complex<float>` for now, as `complex<double>`, for example, will be more complicated.

C++ unit tests are currently passing alongside the tests in `test_scatter_gather_ops.py`. The Python test suites also seem to be passing.

Please keep the following in mind:

  1. I think this is my first time using PyTorch.
  2. This is my first contribution to PyTorch.

Environment:
3080 & WSL 2. nvcc is at 12.4.

cc @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @jgong5 @mingfeima @sanchitintel @ashokei @jingxu10 @min-jean-cho @yanbing-j @Guobing-Chen @Xia-Weiwen @snadampal

@pytorch-bot bot commented Apr 24, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124809

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6bdcc8d with merge base b96b1e8:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ZelboK (Contributor, Author) commented Apr 24, 2024

@janeyx99 Please let me know if I need to do anything else.

@janeyx99 janeyx99 requested review from mikaylagawarecki, peterbell10 and ptrblck and removed request for peterbell10 April 24, 2024 01:59
@janeyx99 (Contributor):
Requesting review from @mikaylagawarecki, who has worked with scatter_gather, and from @ptrblck regarding the CUDA side.

@ptrblck ptrblck requested a review from eqy April 24, 2024 16:18
@soulitzer soulitzer added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Apr 24, 2024
@linux-foundation-easycla bot commented Apr 24, 2024

CLA Signed

The committers listed above are authorized under a signed CLA.

@eqy previously approved these changes Apr 24, 2024
@ZelboK (Contributor, Author) commented Apr 25, 2024

@mikaylagawarecki @eqy I see that the pipeline is failing because of a linting issue. I used lintrunner; I do not believe I touched this line myself. https://github.com/pytorch/pytorch/actions/runs/8824920734/job/24228735042#step:11:244

I can change it if need be.

@ZelboK (Contributor, Author) commented Apr 26, 2024

@eqy Anything that needs to be done on my end?

@mikaylagawarecki (Contributor):
@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 26, 2024
@mikaylagawarecki mikaylagawarecki added release notes: python_frontend python frontend release notes category topic: improvements topic category and removed ciflow/trunk Trigger trunk jobs on your pull request labels Apr 26, 2024
@pytorchmergebot (Collaborator):
Merge failed

Reason: Approvers from one of the following sets are needed:

  • superuser (pytorch/metamates)
  • Core Reviewers (mruberry, lezcano, Skylion007, ngimel, peterbell10)
  • Core Maintainers (soumith, gchanan, ezyang, dzhulgakov, malfet)
Details for Dev Infra team (raised by workflow job)

Failing merge rule: Core Maintainers

@ZelboK (Contributor, Author) commented Apr 30, 2024

@eqy @mikaylagawarecki

Hi folks, since I'm new to PyTorch, I'm curious to know what the procedure is now. Will this be reviewed by a core maintainer/contributor by way of triage? Do I need to do anything on my end?

@mikaylagawarecki (Contributor):
@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 30, 2024
@kit1980 (Contributor) commented May 2, 2024

I'm going to revert this, as multiple internal builds failed with things like:

fbcode/caffe2/aten/src/ATen/cuda/Atomic.cuh(374): error: calling a constexpr __host__ function("operator*=") from a __device__ function("operator()") is not allowed. The experimental flag '--expt-relaxed-constexpr' can be used to allow this.

fbcode/caffe2/aten/src/ATen/cuda/Atomic.cuh(450): error: calling a constexpr __host__ function("max") from a __device__ function("complex_max") is not allowed. The experimental flag '--expt-relaxed-constexpr' can be used to allow this.

@mikaylagawarecki see D56861849 if you want to help re-land this.

@kit1980 (Contributor) commented May 2, 2024

@pytorchbot revert -m "breaking internal builds" -c ghfirst

@pytorchmergebot (Collaborator):
@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request May 2, 2024
…for comp… (#124809)"

This reverts commit 9e24c26.

Reverted #124809 on behalf of https://github.com/kit1980 due to breaking internal builds ([comment](#124809 (comment)))
@pytorchmergebot (Collaborator):
@ZelboK your PR has been successfully reverted.

@pytorch-bot pytorch-bot bot dismissed mikaylagawarecki’s stale review May 2, 2024 21:36

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

@ZelboK (Contributor, Author) commented May 3, 2024

@mikaylagawarecki Really sorry for the trouble :(

> I'm going to revert this, as multiple internal builds failed with things like:
>
> fbcode/caffe2/aten/src/ATen/cuda/Atomic.cuh(374): error: calling a constexpr __host__ function("operator*=") from a __device__ function("operator()") is not allowed. The experimental flag '--expt-relaxed-constexpr' can be used to allow this.
>
> fbcode/caffe2/aten/src/ATen/cuda/Atomic.cuh(450): error: calling a constexpr __host__ function("max") from a __device__ function("complex_max") is not allowed. The experimental flag '--expt-relaxed-constexpr' can be used to allow this.
>
> @mikaylagawarecki see D56861849 if you want to help re-land this.

Is there a way for me to replicate one of these environments, hopefully a devcontainer? I thought there was a workflow for internal builds that needed to pass before merging.

petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
pytorch#124809)

Pull Request resolved: pytorch#124809
Approved by: https://github.com/eqy, https://github.com/mikaylagawarecki
pytorch-bot bot pushed a commit that referenced this pull request May 3, 2024
…for comp… (#124809)"

This reverts commit e09f98c.

Reverted #124809 on behalf of https://github.com/clee2000 due to the Windows build failure being real; https://github.com/pytorch/pytorch/actions/runs/8910674030/job/24470387612#step:11:11236 is the correct failure line. Ignore the statement saying the build passed; batch error codes aren't propagating again ([comment](#124809 (comment)))
pytorch-bot bot pushed a commit that referenced this pull request May 3, 2024
#124809)

Pull Request resolved: #124809
Approved by: https://github.com/mikaylagawarecki
GPU_ATOMIC_INTEGER(Mul, a * b, int32_t)
GPU_ATOMIC_INTEGER(Mul, a * b, int64_t)

inline __device__ c10::complex<float> gpuAtomicMul(c10::complex<float> *address, c10::complex<float> val){
@mikaylagawarecki (Contributor) commented May 3, 2024

The build failure is totally not your fault, as it can't be seen from external CI; we only see it when internal workflows run after the PR is merged.

Looking at the failure and pattern-matching a bit, it looks like maybe we need __host__ __device__ here as well as for complex_max on line 433.

Does this change make sense? I can import the PR and see whether this fixes the internal build tomorrow morning.

@ZelboK (Contributor, Author):

Unfortunately, since complex_min and complex_max both use CUDA intrinsics, they won't compile if you also make them __host__ functions. The use of __fsqrt_rn, for example, should lead to more performant code / better CUDA assembly. CUDA intrinsics should be taken advantage of, imo, because this is kernel code and complex-number operations are heavier computations in general.

The easiest solution would be to add an overload for complex, when compiling with CUDA, so that an operator*= marked __host__ __device__ is available.

Just adding this

#if defined(__CUDACC__) || defined(__HIPCC__)
  template <typename U>
  C10_HOST_DEVICE constexpr complex<T>& operator*=(const complex<U>& rhs) {
    // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i
    T a = real_;
    T b = imag_;
    U c = rhs.real();
    U d = rhs.imag();
    real_ = a * c - b * d;
    imag_ = a * d + b * c;
    return *this;
  }
#endif

in complex.h should fix this problem.

Also, on second look, I made an oversight in the complex_max and complex_min functions: they should use regular comparisons rather than std::max, given they are __device__ functions (see the sketch below). So on that note, it's actually good that this PR got reverted! I will push those changes, and things should build on your end.
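
For reference, a minimal sketch of what the comparison-based __device__ helpers could look like (illustrative only; it keeps the magnitude-based ordering and intrinsics from the PR's snippet, and whether PyTorch should define such an ordering at all is discussed further down in this thread):

inline __device__ c10::complex<float> complex_max(
    c10::complex<float> a, c10::complex<float> b) {
  // Magnitudes via device intrinsics, so this overload stays __device__-only.
  float a_mag = __fsqrt_rn(__fmul_rn(a.real(), a.real()) +
                           __fmul_rn(a.imag(), a.imag()));
  float b_mag = __fsqrt_rn(__fmul_rn(b.real(), b.real()) +
                           __fmul_rn(b.imag(), b.imag()));
  // Plain comparison instead of std::max, which is a constexpr __host__
  // function and triggers the internal build error quoted above.
  return (a_mag > b_mag) ? a : b;
}

// complex_min mirrors the above with (a_mag < b_mag) ? a : b.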

@pytorch-bot pytorch-bot bot added the module: mkldnn Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration label May 3, 2024
@facebook-github-bot (Contributor):
@mikaylagawarecki has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@mikaylagawarecki (Contributor):
@ZelboK I think the *= update in complex.h you mentioned is still needed, as the internal build still has errors like the following:

fbcode/caffe2/aten/src/ATen/cuda/Atomic.cuh(374): error: calling a constexpr __host__ function("operator*=") from a __device__ function("operator()") is not allowed. The experimental flag '--expt-relaxed-constexpr' can be used to allow this.

Also, you seem to have accidentally committed ideep and TensorBase.cpp in this PR 😅

I can import this again to check after you make these fixes, so do let me know when!

FYI: in case you don't get a response from me, I wanted to let you know that I will be out this coming week but will be back on Monday (5/13).

@ZelboK (Contributor, Author) commented May 3, 2024

> @ZelboK I think the *= update in complex.h you mentioned is still needed, as the internal build still has errors like the following:
>
> fbcode/caffe2/aten/src/ATen/cuda/Atomic.cuh(374): error: calling a constexpr __host__ function("operator*=") from a __device__ function("operator()") is not allowed. The experimental flag '--expt-relaxed-constexpr' can be used to allow this.
>
> Also, you seem to have accidentally committed ideep and TensorBase.cpp in this PR 😅
>
> I can import this again to check after you make these fixes, so do let me know when!
>
> FYI: in case you don't get a response from me, I wanted to let you know that I will be out this coming week but will be back on Monday (5/13).

😂 my bad, didn't mean to commit that.

All good, thanks for the prompt assistance! I enjoyed working through this, including the build failures and all. I appreciate the helpfulness from the team ❤️

I'll push the changes.

old = atomicCAS(addr_as_ull, assumed, new_val);
} while (assumed != old);

return *reinterpret_cast<c10::complex<float>*>(&addr_as_ull);
Collaborator:

This isn't atomic? You need to return csum directly, otherwise the value at addr_as_ull may change underneath you.

Collaborator:

In fact this is also wrong as atomic read-modify-write ops return the old value, not the new value. So this should be bit-casting assumed.

@ZelboK (Contributor, Author) commented May 4, 2024

> In fact this is also wrong as atomic read-modify-write ops return the old value, not the new value. So this should be bit-casting assumed.

Sorry for the oversight. Could you help me understand? I know that atomicCAS returns the old value, but what are you referring to with that?

I understand that addr_as_ull shouldn't be returned, as another thread can change it, correct? Why should we use assumed, though, and not csum?

Collaborator:

assumed is the value before performing the update, which is what is returned by normal atomicAdd, atomicMax, etc.

See the CAS implementation for half as an example:

hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
return hsum;

@ZelboK (Contributor, Author) commented May 6, 2024

Oh, you're right. I got tunnel-visioned on the actual call to atomicCAS; yes, it should be assumed. I forgot that I am actually implementing an atomic operation here and that it should follow suit, lol.
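
Putting the reviewer's two points together, a minimal sketch of a corrected CAS loop (illustrative only, not the exact diff; it assumes operator*= on c10::complex<float> is usable from device code, per the complex.h change discussed above):

inline __device__ c10::complex<float> gpuAtomicMul(
    c10::complex<float>* address, c10::complex<float> val) {
  // c10::complex<float> is two 32-bit floats, so it fits in one 64-bit CAS.
  unsigned long long int* addr_as_ull =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *addr_as_ull;
  unsigned long long int assumed;
  do {
    assumed = old;
    // Recompute the product from the value we assume is currently stored.
    c10::complex<float> csum =
        *reinterpret_cast<c10::complex<float>*>(&assumed);
    csum *= val;
    // atomicCAS returns the value it found at the address; retry until the
    // swap succeeds, i.e. nobody changed the value underneath us.
    old = atomicCAS(addr_as_ull, assumed,
                    *reinterpret_cast<unsigned long long int*>(&csum));
  } while (assumed != old);
  // Like atomicAdd/atomicMax, return the *old* value: bit-cast `assumed`
  // rather than re-reading the address, which another thread may have
  // already changed again.
  return *reinterpret_cast<c10::complex<float>*>(&assumed);
}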

__fmul_rn(b.imag(),b.imag())
)
);
return (a_magnitude > b_magnitude) ? a : b;
Collaborator:

Is there any precedent for this definition of complex max/min in PyTorch?

@ZelboK (Contributor, Author) commented May 4, 2024

I am not experienced enough with PyTorch to answer that. Aside from using magnitudes, how else would you order them? I followed the convention from other ecosystems, and from my research this is how it is done across different disciplines/domains, is it not?

Collaborator:

Exactly, they cannot be ordered (in mathematical terms, the complex numbers are not an ordered field). We should error in these cases, the same as we error when we call max on a complex tensor. If people want to use these ops on complex tensors, they can do a view_as_real and perform some transformations on the output to define the order they want.

@ZelboK (Contributor, Author):

I was under the impression that some contexts use magnitude for ordering complex numbers, like spectral analysis in DSP. I also took motivation from https://www.mathworks.com/help/matlab/ref/max.html.

@Franklalalala

Could you comment on whether or not you had a use case for scattering complex numbers? What kind of work were you trying to do? Would you know if ordering of complex numbers is practically useful?

Collaborator:

It may be implemented and it may be useful, but we don't implement that in PyTorch at a kernel level. As mentioned above, all these orderings can often be simulated with the current API and a bit of imagination :)

@ZelboK (Contributor, Author):

I see - in that case I'll wait until @mikaylagawarecki has a chance to review again. Thanks for taking a look!

@Franklalalala:

@ZelboK I am working on Tensor_network, which requires a series of matrix multiplications. In the case of complex elements, the torch scatter cannot be used on GPU. As far as I'm concerned right now, we do not use ordering here, just element-wise multiplication.
By the way, we have worked out a way around it: we transform the complex numbers via the Euler (polar) form, which turns the multiplication into an addition of angles and a multiplication of magnitudes.
Your excellent work has gone beyond my knowledge base, so I cannot give any more advice. But thanks!
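
For readers unfamiliar with the trick described above, a tiny standalone C++ illustration of the polar-form identity (plain std::complex, not PyTorch code):

#include <complex>
#include <iostream>

int main() {
  std::complex<float> a{3.0f, 4.0f};
  std::complex<float> b{-1.0f, 2.0f};

  // Multiplying complex numbers in polar form: magnitudes multiply,
  // angles (arguments) add.
  float mag = std::abs(a) * std::abs(b);
  float ang = std::arg(a) + std::arg(b);
  std::complex<float> via_polar = std::polar(mag, ang);

  // Equal up to floating-point rounding.
  std::cout << a * b << " vs " << via_polar << "\n";
}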

@ZelboK (Contributor, Author):

Thanks a lot for responding, I was genuinely curious. This helps give me perspective :)

@github-actions bot commented Jul 5, 2024

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jul 5, 2024
@github-actions github-actions bot closed this Aug 4, 2024

Labels

  • ciflow/trunk: Trigger trunk jobs on your pull request
  • Merged
  • module: mkldnn: Related to Intel IDEEP or oneDNN (a.k.a. mkldnn) integration
  • open source
  • release notes: python_frontend: python frontend release notes category
  • Reverted
  • Stale
  • topic: improvements: topic category
  • triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

scatter_reduce method do not support complex number multiplication on CUDA