
Conversation

@xwang233 (Collaborator)

No description provided.

@xwang233 (Collaborator, Author)

cc @ptrblck

@dr-ci (bot) commented Aug 25, 2020

💊 CI failures summary and remediations

As of commit 7ff5eb6 (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 1/2 non-CircleCI failure(s)

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_cpp_doc_build (1/1)

Step: "Doc Build and Push" (full log | diagnosis details | 🔁 rerun)

Sep 10 15:40:31 - [class]: torch::serialize::OutputArchive
Sep 10 15:40:31     - [function]: torch::python::bind_module 
Sep 10 15:40:31     - [function]: torch::python::bind_module 
Sep 10 15:40:31     - [function]: torch::python::init_bindings 
Sep 10 15:40:31     - [namespace]: torch::python::detail 
Sep 10 15:40:31       - [function]: torch::python::detail::bind_cpp_module_wrapper 
Sep 10 15:40:31       - [function]: torch::python::detail::py_object_to_device 
Sep 10 15:40:31       - [function]: torch::python::detail::py_object_to_dtype 
Sep 10 15:40:31       - [typedef]: torch::python::detail::PyModuleClass 
Sep 10 15:40:31   - [namespace]: torch::serialize 
Sep 10 15:40:31     - [class]: torch::serialize::InputArchive 
Sep 10 15:40:31     - [class]: torch::serialize::OutputArchive 
Sep 10 15:40:31   - [typedef]: torch::AutoGradMode 
Sep 10 15:40:31   - [typedef]: torch::Deleter 
Sep 10 15:40:31   - [typedef]: torch::Dtype 
Sep 10 15:40:31   - [typedef]: torch::NoGradGuard 
Sep 10 15:40:31   - [variable]: torch::kArea 
Sep 10 15:40:31   - [variable]: torch::kBatchMean 
Sep 10 15:40:31   - [variable]: torch::kBicubic 
Sep 10 15:40:31   - [variable]: torch::kBilinear 
Sep 10 15:40:31   - [variable]: torch::kBorder 
Sep 10 15:40:31   - [variable]: torch::kCircular 

ci.pytorch.org: 1 failed

@xwang233 requested a review from ngimel August 25, 2020 17:11
@mruberry added the module: cuda (Related to torch.cuda, and CUDA support in general), module: half (Related to float16 half-precision floats), and triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) labels Aug 26, 2020
@xwang233 changed the title from "Enable lerp on half type" to "Enable lerp on half type; fix output memory format" Aug 26, 2020
@xwang233 (Collaborator, Author) commented Sep 1, 2020

@ngimel Is there any update on this?

@xwang233 requested a review from ezyang September 3, 2020 04:17
@facebook-github-bot (Contributor) left a comment

@ezyang has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@ezyang merged this pull request in b5d75dd.

@heitorschueroff (Contributor) commented Sep 11, 2020

@xwang233 This commit broke some changes I was making to torch.quantile when giving at::lerp_out a discontiguous view of the out tensor, as follows:

>>> values_below = torch.tensor([[[0., 4.],
         [0., 4.],
         [0., 4.],
         [0., 4.],
         [1., 4.],
         [0., 4.],
         [1., 4.],
         [1., 3.]]], device='cuda:0')

>>> values_above = values_below
>>> weights = torch.tensor([0., 0.], device='cuda:0')
>>> out = torch.empty([2, 1, 8], device='cuda:0')

>>> torch.lerp(values_below, values_above, weights, out=out.unsqueeze(-1).transpose_(0, -1).squeeze_(0))
>>> out
tensor([[[0., 4., 0., 4., 0., 4., 0., 4.]],

        [[1., 4., 0., 4., 1., 4., 1., 3.]]], device='cuda:0')

If instead I restride the out tensor itself in place rather than passing a view, it works correctly:

>>> out.unsqueeze_(-1).transpose_(0, -1).squeeze_(0)
>>> torch.lerp(values_below, values_above, weights, out=out)
>>> out.unsqueeze_(0).transpose_(0, -1).squeeze_(-1)
>>> out
tensor([[[0., 0., 0., 0., 1., 0., 1., 1.]],

        [[4., 4., 4., 4., 4., 4., 4., 3.]]], device='cuda:0')

@ngimel (Collaborator) commented Sep 11, 2020

This behavior is not changed by this PR; it worked like this before. Here

torch.lerp(values_below, values_above, weights, out=out.unsqueeze(-1).transpose_(0, -1).squeeze_(0))

you are passing an intermediate variable as your out kwarg, so the out tensor that you have is not expected to contain the result of your operation; the intermediate variable (which you no longer have access to) will.

out1=torch.lerp(vals, values_above, weights, out=out.unsqueeze_(-1).transpose_(0, -1).squeeze_(0))

produces the expected results (that is, out and out1 are the same). Note the in-place unsqueeze_.
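For reference, a minimal sketch of the distinction ngimel is pointing at (out-of-place unsqueeze returns a new tensor object, while in-place unsqueeze_ modifies the tensor you already hold); the variable names here are just illustrative:

import torch

t = torch.empty(2, 1, 8)

# Out-of-place: returns a new tensor that shares storage with t but has
# its own size/stride metadata; t itself is left untouched.
v = t.unsqueeze(-1)
print(v is t)    # False
print(t.shape)   # torch.Size([2, 1, 8])

# In-place: modifies t's own metadata and returns t, so the object passed
# as out= is the same object you still have a reference to.
u = t.unsqueeze_(-1)
print(u is t)    # True
print(t.shape)   # torch.Size([2, 1, 8, 1])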

@heitorschueroff (Contributor) commented Sep 11, 2020

I removed the b_self.suggest_memory_format() calls added by this PR, and torch.lerp(values_below, values_above, weights, out=out.unsqueeze(-1).transpose_(0, -1).squeeze_(0)) worked the same as when using unsqueeze_. But as you said, maybe the current version is how it should be and I was just lucky before.

I changed my code to work with the current version so there's no need to change it if it's working correctly.

@ngimel (Collaborator) commented Sep 11, 2020

That's weird:

  1. Colab with PyTorch 1.6 produces different out and out1 here: https://colab.research.google.com/drive/1lG4h5_ti9SiZcPKElJDFnpqcLxzFT_XN#scrollTo=LpxNXGZQlsic
  2. If you send in a tensor with the correct size, resize_as_ does nothing, regardless of the requested memory format.
  3. suggest_memory_format will never return anything other than contiguous for a 3-D tensor (see the sketch below), which would also be the resize_as_ behavior before this PR.
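A rough Python approximation of that last point, just for illustration; the real suggest_memory_format logic lives in ATen and covers more cases, so treat this as a sketch rather than the actual implementation:

import torch

def suggested_format(t):
    # Sketch: channels_last is only ever suggested for 4-D tensors whose
    # strides already look channels-last; everything else falls back to
    # contiguous_format.
    if t.dim() == 4 and t.is_contiguous(memory_format=torch.channels_last):
        return torch.channels_last
    return torch.contiguous_format

print(suggested_format(torch.empty(2, 1, 8)))   # torch.contiguous_format
nhwc = torch.empty(2, 3, 4, 5).to(memory_format=torch.channels_last)
print(suggested_format(nhwc))                   # torch.channels_last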

@heitorschueroff (Contributor)

The problem is not that out and out1 are different; that is OK. The problem is that in 1.6 out was correctly computed as

tensor([[[0., 0., 0., 0., 1., 0., 1., 1.]],

        [[4., 4., 4., 4., 4., 4., 4., 3.]]], device='cuda:0')

But with the introduction of this PR, the result from the colab would be

tensor([[[0., 4., 0., 4., 0., 4., 0., 4.]],

        [[1., 4., 0., 4., 1., 4., 1., 3.]]], device='cuda:0')

@ngimel (Collaborator) commented Sep 11, 2020

Oh, thank you, you are right! Yes, this is wrong behavior: normally, if out is the correct size, we should not change its strides, but now we make it contiguous:

In [36]: out=torch.empty([2,1,8], device="cuda")                                                                                                                                          

In [37]: out_1=out.unsqueeze(-1).transpose(0, -1).squeeze(0)                                                                                                                              

In [38]: out_1.is_contiguous()                                                                                                                                                            
Out[38]: False

In [39]: out_1.stride()                                                                                                                                                                   
Out[39]: (8, 1, 8)

In [40]: out_1.size()                                                                                                                                                                     
Out[40]: torch.Size([1, 8, 2])

In [41]: torch.lerp(vals, values_above, weights, out=out_1)                                                                                                                               
Out[41]: 
tensor([[[0., 4.],
         [0., 4.],
         [0., 4.],
         [0., 4.],
         [1., 4.],
         [0., 4.],
         [1., 4.],
         [1., 3.]]], device='cuda:0')

In [42]: out_1.size()                                                                                                                                                                     
Out[42]: torch.Size([1, 8, 2])

In [43]: out_1.is_contiguous()                                                                                                                                                            
Out[43]: True

In [44]: out_1.stride()                                                                                                                                                                   
Out[44]: (16, 2, 1)

We should change the resize_/resize_as_ implementation to respect that, but in the meantime, can you please remove suggest_memory_format from lerp?

@heitorschueroff (Contributor) commented Sep 11, 2020

Sure, removed here #44559.

@xwang233 (Collaborator, Author)

@heitorschueroff thanks for the test! I remember the resize with suggest_memory_format() was supposed to fix the test when the input tensor is channels_last, which is why I changed it. If we simply remove it, I'm not sure the output will still be channels_last.

@ngimel I think the question with out= is: what are we expecting from out=some_tensor? Should we respect the size and strides of the some_tensor passed in by the user, or should we just treat some_tensor as an empty Python variable and assign a new output tensor to it?

I remember there was a similar discussion here: #41027.

@heitorschueroff (Contributor)

@xwang233 I believe the intent is to move away from resizing output tensors unless they are empty (x.numel() == 0), as described in #42079.
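A minimal sketch of that convention (a hypothetical helper, not the actual ATen code path): resize out only when it is empty; otherwise keep its existing size and strides and just copy the result into it.

import torch

def write_into_out(result, out):
    # Hypothetical helper illustrating the out= convention from #42079:
    # resize only when out is empty; otherwise leave its strides alone.
    if out.numel() == 0:
        out.resize_(result.shape)
    elif out.shape != result.shape:
        raise RuntimeError("out has the wrong shape and is not empty")
    out.copy_(result)
    return out

out = torch.empty(2, 1, 8)
restrided = out.unsqueeze_(-1).transpose_(0, -1).squeeze_(0)  # restride out in place
write_into_out(torch.zeros(1, 8, 2), restrided)
print(restrided.stride())   # (8, 1, 8) -- the discontiguous strides are preserved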

bitfort pushed a commit that referenced this pull request Sep 11, 2020
Summary: Pull Request resolved: #43541

Reviewed By: zou3519

Differential Revision: D23499592

Pulled By: ezyang

fbshipit-source-id: 9efdd6cbf0a334ec035ddd467667ba874b892549
@ngimel (Collaborator) commented Sep 11, 2020

What are the shape requirements for the result tensor here? Can TensorIterator figure out what size it should be? If so, then you should not resize it in the lerp functions; let TensorIterator handle it.
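As a quick illustration of what that would mean (using torch.add to stand in for lerp; the assumption here is that a TensorIterator-backed element-wise op leaves a correctly sized out tensor's strides alone):

import torch

a = torch.arange(16.).reshape(4, 4)
b = torch.ones(4, 4)

# A correctly sized but non-contiguous output tensor.
out = torch.empty(4, 4).t()
print(out.stride())              # (1, 4)

torch.add(a, b, out=out)         # element-wise op routed through TensorIterator
print(out.stride())              # (1, 4) -- strides are left alone
print(torch.equal(out, a + b))   # True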

heitorschueroff added a commit that referenced this pull request Sep 14, 2020
Please refer to the discussion at the bottom of #43541 about the bug.

Differential Revision: [D23655403](https://our.internmc.facebook.com/intern/diff/D23655403)

[ghstack-poisoned]
facebook-github-bot pushed a commit that referenced this pull request Sep 14, 2020
Summary:
Pull Request resolved: #44559

Please refer to the discussion at the bottom of #43541 about the bug.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D23655403

Pulled By: heitorschueroff

fbshipit-source-id: 10e4ce5c2fe7bf6e95bcfac4033202430292b03f
xuzhao9 pushed a commit that referenced this pull request Sep 18, 2020
Summary: Pull Request resolved: #43541

Reviewed By: zou3519

Differential Revision: D23499592

Pulled By: ezyang

fbshipit-source-id: 9efdd6cbf0a334ec035ddd467667ba874b892549
xuzhao9 pushed a commit that referenced this pull request Sep 18, 2020
Summary:
Pull Request resolved: #44559

Please refer to the discussion at the bottom of #43541 about the bug.

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D23655403

Pulled By: heitorschueroff

fbshipit-source-id: 10e4ce5c2fe7bf6e95bcfac4033202430292b03f

Labels

Merged · module: cuda (Related to torch.cuda, and CUDA support in general) · module: half (Related to float16 half-precision floats) · open source · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
