Skip to content

Conversation

@t-vi
Copy link
Collaborator

@t-vi t-vi commented Apr 25, 2018

So here is an implementation of infinity norm for norm and renorm, also doing backwards.
I also added tests against numpy / for the gradient.

Best regards

Thomas

@elanmart
Copy link
Contributor

numpy also supports -inf, perhaps it would be nice to add it to this PR as well?

Copy link
Contributor

@ezyang ezyang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems legit.

@ezyang
Copy link
Contributor

ezyang commented Apr 26, 2018

CI error looks real though.

@apaszke
Copy link
Contributor

apaszke commented Apr 26, 2018

@pytorchbot retest this please

@apaszke apaszke merged commit f98b778 into pytorch:master Apr 26, 2018
@apaszke
Copy link
Contributor

apaszke commented Apr 26, 2018

Thanks @t-vi!

facebook-github-bot pushed a commit that referenced this pull request Oct 18, 2018
Summary:
I found a bug in norm() and fixed it (and added tests to make sure it's fixed)
here is how to reproduce it:
```python
import torch
x = torch.FloatTensor([[10, 12, 13], [4, 0, 12]])
print(torch.norm(x, -40, dim=0, keepdim=True)) #output is tensor([[ 4.0000,  0.0000, 11.9853]])
print(torch.norm(x, float('-inf'), dim=0, keepdim=True)) #output is tensor([[1., 1., 1.]]) which is wrong!
from numpy.linalg import norm as np_norm
x = x.numpy()
print(np_norm(x, ord=-40, axis=0)) #output is array([[4., 0., 11.985261]])
print(np_norm(x, ord=float('-inf'), axis=0)) #output is array([[4., 0., 12.0]])
```
it's related to [#6817](#6817) and [#6969](#6969)
Pull Request resolved: #12722

Differential Revision: D10427687

Pulled By: soumith

fbshipit-source-id: 936a7491d1e2625410513ee9c39f8c910e8e6803
zdevito pushed a commit to zdevito/ATen that referenced this pull request Oct 18, 2018
Summary:
I found a bug in norm() and fixed it (and added tests to make sure it's fixed)
here is how to reproduce it:
```python
import torch
x = torch.FloatTensor([[10, 12, 13], [4, 0, 12]])
print(torch.norm(x, -40, dim=0, keepdim=True)) #output is tensor([[ 4.0000,  0.0000, 11.9853]])
print(torch.norm(x, float('-inf'), dim=0, keepdim=True)) #output is tensor([[1., 1., 1.]]) which is wrong!
from numpy.linalg import norm as np_norm
x = x.numpy()
print(np_norm(x, ord=-40, axis=0)) #output is array([[4., 0., 11.985261]])
print(np_norm(x, ord=float('-inf'), axis=0)) #output is array([[4., 0., 12.0]])
```
it's related to [#6817](pytorch/pytorch#6817) and [#6969](pytorch/pytorch#6969)
Pull Request resolved: pytorch/pytorch#12722

Differential Revision: D10427687

Pulled By: soumith

fbshipit-source-id: 936a7491d1e2625410513ee9c39f8c910e8e6803
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants