
Conversation

@kshitij12345
Collaborator

Fixes #42747

@kshitij12345 kshitij12345 marked this pull request as draft August 13, 2020 13:27
@kshitij12345 kshitij12345 marked this pull request as ready for review August 13, 2020 13:27
@dr-ci

dr-ci bot commented Aug 13, 2020

💊 CI failures summary and remediations

As of commit 66ee099 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@gchanan gchanan requested a review from mruberry August 13, 2020 18:01
@gchanan gchanan added the triaged label Aug 13, 2020
@mruberry mruberry requested a review from anjali411 August 13, 2020 20:43
@anjali411 anjali411 added the module: complex label Aug 13, 2020
  // For complex dtypes.
  dot_check(self, other);
  return AT_DISPATCH_COMPLEX_TYPES(self.scalar_type(), "vdot", [&] {
    Tensor result = at::empty({}, self.options());
Collaborator

@mruberry mruberry Aug 13, 2020

at::empty(0, ...

This actually trips up experienced PyTorch developers all the time. If we take a look at empty:

Tensor empty_cpu(IntArrayRef size, const TensorOptions& options_, c10::optional<c10::MemoryFormat> optional_memory_format) {

the size comes from here:

int64_t nelements = prod_intlist(size);

which is computed using

inline int64_t prod_intlist(ArrayRef<int64_t> list) {

And if we run the following program:

std::cout << prod_intlist({}) << std::endl;  // prints 1
std::cout << prod_intlist(0) << std::endl;   // prints 0

it will print 1, then 0. It's a bit confusing that an empty initializer list produces a tensor with one element while putting a zero in produces a tensor with no elements, but I try to remember it by recalling that the zero specifies the size.

Collaborator

Oh wait. You're immediately populating the value into result, not resizing it (which I assumed since it's such a common pattern). You're doing everything right.

Collaborator

(Although what about empty(1, ... for readability?)

Collaborator

empty({}) produces a 0-dim tensor with 1 element; empty({1}) produces a 1-dim tensor with 1 element. vdot has to return a 0-dim tensor, so empty({}) is correct.
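
For reference, a quick Python-level illustration of the same distinction (just printing dimensionality and element counts):

import torch

print(torch.empty(()).dim(), torch.empty(()).numel())  # 0-dim, 1 element
print(torch.empty(1).dim(), torch.empty(1).numel())    # 1-dim, 1 element
print(torch.empty(0).dim(), torch.empty(0).numel())    # 1-dim, 0 elements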

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_vdot(self, device, dtype):
        def compare_with_numpy_bin_op(torch_fn, np_fn, x, y, relaxed_tolerance=False):
            if self.device_type == 'cuda':
Collaborator

@mruberry mruberry Aug 13, 2020

To fix the XLA issue and simplify the code, just always do:

y_np = y.cpu().numpy()

There's no harm in calling .cpu() on a CPU tensor.
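
A minimal sketch of the simplified helper under that suggestion (the body below is illustrative, not the PR's final test code):

import numpy as np

def compare_with_numpy_bin_op(torch_fn, np_fn, x, y):
    # .cpu() is a no-op for CPU tensors, so this handles CPU, CUDA, and XLA uniformly.
    x_np = x.cpu().numpy()
    y_np = y.cpu().numpy()
    expected = np_fn(x_np, y_np)
    actual = torch_fn(x, y).cpu().numpy()
    np.testing.assert_allclose(actual, expected)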

Collaborator

@mruberry mruberry left a comment

Overall looks really good as usual, @kshitij12345!

What about adding a derivative, like we have for dot:

- name: dot(Tensor self, Tensor tensor) -> Tensor

And updating method_tests():

('dot', (L,), ((L,),), '', (True,)),

to test it?

@anjali411 for help with the derivative.
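
If a derivative does get added, a quick sanity check for the real dtypes is gradcheck; this is only an illustrative sketch, not part of the PR:

import torch

x = torch.randn(5, dtype=torch.double, requires_grad=True)
y = torch.randn(5, dtype=torch.double, requires_grad=True)

# Compares the analytical gradient of vdot against a numerical estimate.
print(torch.autograd.gradcheck(torch.vdot, (x, y)))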

        self.assertEqual(res1, out)

    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_vdot(self, device, dtype):
Collaborator

@mruberry mruberry Aug 13, 2020

Would you also add tests for:

  • dot and vdot getting arguments of different dtypes, an incorrect number of dimensions, and mismatched devices (this test is a little weird to write in the device-generic framework, but you can check that your current device is a cuda device and then create a tensor on that device + a cpu tensor to do it); a rough sketch follows after this list
  • after the out variant is added, including mismatched out dtype
  • after the method variant is added, you can add a test here, too:

('dot', '', _medium_1d, lambda t, d: [_medium_1d(t, d)],
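
A rough sketch of the error checks from the first bullet (the helper name and device plumbing here are assumptions, not the PR's final test code):

import torch

def check_vdot_errors(device='cpu'):
    x = torch.randn(3, device=device)

    # Mismatched dtypes should raise a RuntimeError.
    try:
        torch.vdot(x, torch.randn(3, dtype=torch.double, device=device))
    except RuntimeError as e:
        print('dtype mismatch:', e)

    # vdot expects 1-D tensors, so a 2-D argument should raise as well.
    try:
        torch.vdot(x, torch.randn(3, 3, device=device))
    except RuntimeError as e:
        print('dimension mismatch:', e)

check_vdot_errors()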

Contributor

r"""
vdot(x, y) -> Tensor
Computes the dot product (inner product) of two tensors.
Collaborator

@mruberry mruberry Aug 13, 2020

What about something like:

"Computes the dot product (inner product) of input and the complex conjugate of other."

And not having to explain in a note how vdot is distinct from dot?

Contributor

I think we should include the note in the description because it makes the description clearer.

"Computes the dot product (inner product) of two tensors. The vdot(a, b) function handles complex numbers differently than dot(a, b). If the first argument is complex the complex conjugate of the first argument is used for the calculation of the dot product."

Collaborator

May still need to inline this note per @anjali411's feedback. I'll let her have final say over the docstring.

Contributor

@anjali411 anjali411 Aug 25, 2020

@kshitij12345 nit -- can you change the (new line) formatting to:
"Computes the dot product (inner product) of two tensors. The vdot(a, b) function
handles complex numbers differently than dot(a, b). If the first argument is complex
the complex conjugate of the first argument is used for the calculation of the dot product."

Collaborator Author

Done. Thanks!


Tensor vdot_cuda(const Tensor& self, const Tensor& other) {
  if (!self.is_complex()) {
    return dot_cuda(self, other);
Contributor

>>> np.vdot(np.array([2]), np.array([2+3j]))
(4+6j)
>>> np.vdot(np.array([2+3j]), np.array([2]))
(4-6j)

Maybe it's worth adding a comment that we only call dot_cuda when self is not complex, because we want the above-mentioned behavior.

Contributor

Never mind, I guess we throw an error for input tensors of different dtypes:

>>> torch.dot(torch.tensor([1j]), torch.tensor([2]))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: dot : expected both vectors to have same dtype, but found ComplexFloat and Long

I think we should add type promotion for dot and vdot to better align with numpy.

cc. @mruberry

Collaborator

Type promotion would be nice but I think it's OK if it's not in this first PR. Type promotion is a little tricky to implement today when not using TensorIterator.

Collaborator Author

Yeah, I was thinking about it. But since the plan is to use the existing dot for real types, I wasn't sure how to go about adding type promotion.

Another place where it diverges from numpy is that the numpy operator supports broadcasting while dot doesn't. We could add broadcasting logic before passing the inputs to dot. What do you think about it?

Collaborator

I would file follow-up issues for broadcasting + type promotion. We may even want to add architecture to support doing these things easily outside of TensorIterator.

Collaborator

numpy vdot does not support broadcasting, and we don't try to follow numpy dot behavior:

In [5]: np.vdot([2,3], [3])                                                                                                                                                                                                         
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-5-6192cfa76e6f> in <module>
----> 1 np.vdot([2,3], [3])

<__array_function__ internals> in vdot(*args, **kwargs)

ValueError: cannot reshape array of size 1 into shape (2,)

Contributor

@anjali411 anjali411 Aug 14, 2020

@mruberry yeah, that sounds good. Out of curiosity, what's tricky about implementing type promotion when not using TensorIterator?

@kshitij12345

if (!self.is_complex()) {
    return dot_cuda(self, other);
}

will give us the desired behavior for the (real, complex) dot product when type promotion is enabled, so we should still add a note documenting that. That way, when type promotion is enabled in the future, the behavior is already documented.

Collaborator

Type promotion challenges:

  • sometimes have to be careful to preserve your inputs (luckily not the case here)
  • compute the result type
  • cast inputs to the result type
  • validate safe casting to out
  • copy to out (if necessary)
  • write a custom test for your op's type promotion behavior

It's not the end of the world. In this case we would want to change the behavior of dot, too, so it seems separable.
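
For illustration only, a hypothetical Python-level wrapper (not part of this PR) that gets numpy-like behavior by computing the promoted dtype up front:

import torch

def vdot_with_promotion(a, b):
    # torch.result_type applies the standard type-promotion rules,
    # mirroring what TensorIterator would do automatically.
    common = torch.result_type(a, b)
    return torch.vdot(a.to(common), b.to(common))

print(vdot_with_promotion(torch.tensor([1j]), torch.tensor([2])))  # no dtype-mismatch error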

Contributor

@anjali411 anjali411 left a comment

I left some minor comments. LGTM overall :)

  1. We should add a not_implemented definition in derivatives.yaml until the JAX vs tf issue is resolved:

- name: vdot(Tensor self, Tensor other) -> Tensor
  self: 'not_implemented("vdot: self")'
  other: 'not_implemented("vdot: other")'

  2. We should not throw an error for input tensors of different dtypes, to be more consistent with numpy. cc @mruberry

There will be some merge conflicts with #42745 once this PR is merged.

@kshitij12345
Collaborator Author

Gentle ping :)

@jeffdaily
Collaborator

ROCm CI passed with latest changes. LGTM.

@mruberry
Collaborator

You also need to update tensors.rst like torch.rst.

@mruberry mruberry self-requested a review August 25, 2020 05:09
@kshitij12345
Collaborator Author

> You also need to update tensors.rst like torch.rst.

Thanks, I had missed that. I have also updated the description of the function.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


jeffdaily added a commit to ROCm/pytorch that referenced this pull request Aug 27, 2020
Revert "Skips some complex tests on ROCm (pytorch#42759)"
This reverts commit 55b1706.

Use new cuda_to_hip_mappings.py from pytorch#43004.
facebook-github-bot pushed a commit that referenced this pull request Aug 31, 2020
Summary:
Revert "Skips some complex tests on ROCm (#42759)".  This reverts commit 55b1706.

Use new cuda_to_hip_mappings.py from #43004.

Fixes #42383 (comment)

CC sunway513

Pull Request resolved: #43744

Reviewed By: glaringlee

Differential Revision: D23391263

Pulled By: ngimel

fbshipit-source-id: ddf734cea3ba69c24f0d79cf1b87c05cdb45ec3d
@kshitij12345
Collaborator Author

Gentle Ping :)

Collaborator

@mruberry mruberry left a comment

Ensuring Phabricator diff is stamped.

@anjali411
Contributor

@kshitij12345 thanks for the reminder! The FB tests for this PR were failing, hence it took a while. Can you rebase the PR?

@codecov

codecov bot commented Sep 1, 2020

Codecov Report

Merging #43004 into master will increase coverage by 0.00%.
The diff coverage is 100.00%.


@@           Coverage Diff           @@
##           master   #43004   +/-   ##
=======================================
  Coverage   69.29%   69.29%           
=======================================
  Files         379      379           
  Lines       47036    47038    +2     
=======================================
+ Hits        32592    32594    +2     
  Misses      14444    14444           
Impacted Files Coverage Δ
torch/overrides.py 98.01% <ø> (ø)
torch/utils/hipify/cuda_to_hip_mappings.py 100.00% <ø> (ø)
torch/_tensor_docs.py 100.00% <100.00%> (ø)
torch/_torch_docs.py 100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

Contributor

@facebook-github-bot facebook-github-bot left a comment

@anjali411 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@anjali411 merged this pull request in b6b5ebc.

@kshitij12345 kshitij12345 deleted the develop/numpy/vdot branch September 11, 2020 09:17

Labels

Merged
module: complex (Related to complex number support in PyTorch)
open source
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)


Development

Successfully merging this pull request may close these issues.

Add torch.vdot similar to numpy.vdot to calculate the complex dot product
