Commit 757e6d8

Commit message: new norm
1 parent 1d399a8 commit 757e6d8

10 files changed: +181 additions, -60 deletions

aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 51 additions & 1 deletion
@@ -307,7 +307,7 @@ static inline Tensor& bmm_out_or_baddbmm_(Tensor& self_or_result, const Tensor&
   } else if (contraction_size == 0) {
     return self_or_result.zero_();
   }
-
+
   auto batch_items_contiguous_or_transposed = [&](const Tensor& t) {
     return (t.stride(2) == 1 && t.stride(1) == t.size(2))
         || (t.stride(1) == 1 && t.stride(2) == t.size(1));
@@ -536,5 +536,55 @@ Tensor matrix_power(const Tensor& a, int64_t n) {
   return result;
 }
 
+Tensor frobenius_norm(const Tensor& self) {
+  return at::norm(self);
+}
+
+Tensor frobenius_norm(const Tensor& self, IntList dim, bool keepdim) {
+  AT_CHECK(
+      dim.size() <= 2,
+      "Expected at most 2 dimensions, but got ",
+      dim.size(),
+      " dimensions instead.");
+  if (dim.size() == 1) {
+    return at::norm(self, 2, dim[0], keepdim);
+  }
+  return at::sqrt(at::sum(self * self, dim, keepdim));
+}
+
+Tensor &frobenius_norm_out(
+    Tensor& result,
+    const Tensor& self,
+    IntList dim,
+    bool keepdim) {
+  AT_CHECK(
+      dim.size() <= 2,
+      "Expected at most 2 dimensions, but got ",
+      dim.size(),
+      " dimensions instead.");
+  if (dim.size() == 1) {
+    return at::norm_out(result, self, 2, dim[0], keepdim);
+  }
+  return at::sqrt_out(result, at::sum(self * self, dim, keepdim));
+}
+
+Tensor nuclear_norm(const Tensor& self, bool keepdim) {
+  AT_CHECK(
+      self.dim() == 2,
+      "Expected a tensor with 2 dimensions, but got a ",
+      self.dim(),
+      " dimensions tensor instead.");
+  return at::sum(std::get<1>(at::svd(self)), 0, keepdim);
+}
+
+Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
+  AT_CHECK(
+      self.dim() == 2,
+      "Expected a tensor with 2 dimensions, but got a ",
+      self.dim(),
+      " dimensions tensor instead.");
+  return at::sum_out(result, std::get<1>(at::svd(self)), 0, keepdim);
+}
+
 } // namespace native
 } // namespace at
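The two new matrix norms have simple closed forms: the Frobenius norm is the square root of the sum of squared entries, and the nuclear norm is the sum of the singular values. A minimal Python sketch of the same math, usable as a sanity check against the new ATen kernels once this commit is built (the random input and tolerances are illustrative, not part of the commit):

import torch

x = torch.randn(4, 4)

# Frobenius norm: sqrt(sum of squared entries), mirroring frobenius_norm above.
fro_ref = torch.sqrt((x * x).sum())

# Nuclear norm: sum of singular values, mirroring nuclear_norm above.
nuc_ref = torch.svd(x)[1].sum()

# Both should agree with the new entry points, up to float tolerance.
assert torch.allclose(torch.norm(x, 'fro'), fro_ref)
assert torch.allclose(torch.norm(x, 'nuc'), nuc_ref)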

aten/src/ATen/native/native_functions.yaml

Lines changed: 15 additions & 0 deletions
@@ -1820,6 +1820,21 @@
   python_default_init:
     p: 2
 
+- func: frobenius_norm(Tensor self) -> Tensor
+  variants: function
+
+- func: frobenius_norm(Tensor self, IntList[1] dim, bool keepdim=false) -> Tensor
+  variants: function
+
+- func: frobenius_norm_out(Tensor result, Tensor self, IntList[1] dim, bool keepdim=false) -> Tensor
+  variants: function
+
+- func: nuclear_norm(Tensor self, bool keepdim=false) -> Tensor
+  variants: function
+
+- func: nuclear_norm_out(Tensor result, Tensor self, bool keepdim=false) -> Tensor
+  variants: function
+
 - func: native_clone(Tensor self) -> Tensor
   dispatch:
     SparseCPU: clone_sparse
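These native_functions.yaml entries drive ATen's code generation: each func: line produces a declaration for the C++ implementation in LinearAlgebra.cpp above, and variants: function exposes it as a free function (not a Tensor method), which is how it becomes reachable from Python. A hedged sketch of the resulting call surface — the exact binding location is a codegen internal, but it is what torch/functional.py below dispatches to:

import torch

x = torch.randn(3, 3)

# The Python dispatcher in torch/functional.py calls these bindings directly.
f = torch._C._VariableFunctions.frobenius_norm(x)
n = torch._C._VariableFunctions.nuclear_norm(x, keepdim=False)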

test/onnx/test_operators.py

Lines changed: 1 addition & 1 deletion
@@ -426,7 +426,7 @@ def test_repeat_dim_overflow(self):
 
     def test_norm(self):
         x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True)
-        self.assertONNX(lambda x: x.norm(dim=2), (x))
+        self.assertONNX(lambda x: x.norm(p=2, dim=2), (x))
 
     def test_upsample(self):
         x = Variable(torch.randn(1, 2, 3, 4), requires_grad=True)
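The explicit p=2 follows from the new default: x.norm(dim=2) now means p='fro', and under the new dispatcher the two spellings trace to different ATen ops — presumably the existing ONNX symbolic covers the p-norm overload but not the new frobenius_norm op. A hedged reading of what each call now reaches (see torch/functional.py below):

# Under the new dispatch:
# x.norm(dim=2)       -> p='fro' by default -> frobenius_norm(x, [2], False)
# x.norm(p=2, dim=2)  -> explicit p         -> norm(x, 2, 2, False)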

test/test_autograd.py

Lines changed: 14 additions & 2 deletions
@@ -2997,16 +2997,28 @@ class dont_convert(tuple):
     ('zero_', (), NO_ARGS, 'scalar'),
     ('logsumexp', (S, S), (1,)),
     ('logsumexp', (), (0,), 'scalar'),
-    ('norm', (S, S), (2,)),
+    ('norm', (S, S), (), 'default'),
+    ('norm', (S, S), (2,), '2'),
     ('norm', (S, S), (0,), '0'),
     ('norm', (S, S), (0.5,), '0_5'),
     ('norm', (S, S), (1,), '1'),
     ('norm', (S, S), (3,), '3'),
     ('norm', (S, S), (inf,), 'inf'),
+    ('norm', (S, S), ('fro',), 'fro_default'),
+    ('norm', (S, S), ('fro', [0, 1],), 'fro'),
+    ('norm', (S, S), ('nuc',), 'nuc'),
     ('norm', (S, S), (-1,), 'neg_1'),
+    ('norm', (S, S), (-2,), 'neg_2'),
     ('norm', (S, S), (-0.5,), 'neg_0_5'),
     ('norm', (S, S), (-1.5,), 'neg_1_5'),
-    ('norm', torch.rand(S, S, S) + 5e-2, (1.5,), '1_5'),
+    ('norm', (S, S), (-2, 1,), 'neg_2_2_dim', [1]),
+    ('norm', (S, S), (-1, 1,), 'neg_1_2_dim', [1]),
+    ('norm', (S, S), (0, 1,), '0_2_dim', [1]),
+    ('norm', (S, S), (1, 1,), '1_2_dim', [1]),
+    ('norm', (S, S), (2, 1,), '2_2_dim', [1]),
+    ('norm', (S, S), (3, 1,), '3_2_dim', [1]),
+    ('norm', (S, S), (inf, 1,), 'inf_2_dim'),
+    ('norm', torch.rand(S, S, S) + 5e-2, (1.5,), '1_5_default'),
     ('norm', (S, S, S), (2, 1), '2_dim', [1]),
     ('norm', (S, S, S), (3, 1), '3_dim', [1]),
     ('norm', torch.rand(S, S, S) + 5e-2, (1.5, 1), '1_5_dim', [1]),

test/test_jit.py

Lines changed: 6 additions & 0 deletions
@@ -7556,6 +7556,12 @@ def forward(self, x, y):
     'test_var_dim_1d',
     'test_var_dim_1d_neg0',
     'test_var_dim_neg0',
+    'test_norm_inf',
+    'test_norm_inf_2_dim',
+    'test_norm_fro',
+    'test_norm_fro_default',
+    'test_norm_nuc',
+    'test_renorm_norm_inf',
     'test_matrix_power_n=-1',  # involves inverse
     'test_matrix_power_n=-3',  # involves inverse
     # skipped nn functional tests

test/test_torch.py

Lines changed: 8 additions & 0 deletions
@@ -842,6 +842,7 @@ def _test_norm(self, device):
             res = x.norm(p).item()
             expected = np.linalg.norm(xn, p)
             self.assertEqual(res, expected, "full reduction failed for {}-norm".format(p))
+
         # one dimension
         x = torch.randn(5, 5, device=device)
         xn = x.cpu().numpy()
@@ -851,6 +852,13 @@ def _test_norm(self, device):
             self.assertEqual(res.shape, expected.shape)
             self.assertTrue(np.allclose(res, expected), "dim reduction failed for {}-norm".format(p))
 
+        # matrix norm
+        for p in ['fro', 'nuc']:
+            res = x.norm(p).cpu().numpy()
+            expected = np.linalg.norm(xn, p)
+            self.assertEqual(res.shape, expected.shape)
+            self.assertTrue(np.allclose(res, expected), "dim reduction failed for {}-norm".format(p))
+
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_norm(self):
         self._test_norm(self, device='cpu')

torch/_torch_docs.py

Lines changed: 0 additions & 54 deletions
@@ -3059,60 +3059,6 @@ def parse_kwargs(desc):
         [ 3, 3]])
 """)
 
-add_docstr(torch.norm,
-           r"""
-.. function:: norm(input, p=2) -> Tensor
-
-Returns the p-norm of the :attr:`input` tensor.
-
-.. math::
-    ||x||_{p} = \sqrt[p]{x_{1}^{p} + x_{2}^{p} + \ldots + x_{N}^{p}}
-
-Args:
-    input (Tensor): the input tensor
-    p (float, optional): the exponent value in the norm formulation
-Example::
-
-    >>> a = torch.randn(1, 3)
-    >>> a
-    tensor([[-0.5192, -1.0782, -1.0448]])
-    >>> torch.norm(a, 3)
-    tensor(1.3633)
-
-.. function:: norm(input, p, dim, keepdim=False, out=None) -> Tensor
-
-Returns the p-norm of each row of the :attr:`input` tensor in the given
-dimension :attr:`dim`.
-
-If :attr:`keepdim` is ``True``, the output tensor is of the same size as
-:attr:`input` except in the dimension :attr:`dim` where it is of size 1.
-Otherwise, :attr:`dim` is squeezed (see :func:`torch.squeeze`), resulting
-in the output tensor having 1 fewer dimension than :attr:`input`.
-
-Args:
-    input (Tensor): the input tensor
-    p (float): the exponent value in the norm formulation
-    dim (int): the dimension to reduce
-    keepdim (bool): whether the output tensor has :attr:`dim` retained or not
-    out (Tensor, optional): the output tensor
-
-Example::
-
-    >>> a = torch.randn(4, 2)
-    >>> a
-    tensor([[ 2.1983,  0.4141],
-            [ 0.8734,  1.9710],
-            [-0.7778,  0.7938],
-            [-0.1342,  0.7347]])
-    >>> torch.norm(a, 2, 1)
-    tensor([ 2.2369,  2.1558,  1.1113,  0.7469])
-    >>> torch.norm(a, 0, 1, True)
-    tensor([[ 2.],
-            [ 2.],
-            [ 2.],
-            [ 2.]])
-""")
-
 add_docstr(torch.normal,
            r"""
 .. function:: normal(mean, std, out=None) -> Tensor

torch/autograd/gradcheck.py

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@ def make_jacobian(input, num_out):
         if not input.requires_grad:
             return None
         return torch.zeros(input.nelement(), num_out, dtype=input.dtype)
-    elif isinstance(input, container_abcs.Iterable):
+    elif isinstance(input, container_abcs.Iterable) and not isinstance(input, str):
         jacobians = list(filter(
             lambda x: x is not None, (make_jacobian(elem, num_out) for elem in input)))
         if not jacobians:
@@ -37,7 +37,7 @@ def iter_tensors(x, only_requiring_grad=False):
     if isinstance(x, torch.Tensor):
         if x.requires_grad or not only_requiring_grad:
             yield x
-    elif isinstance(x, container_abcs.Iterable):
+    elif isinstance(x, container_abcs.Iterable) and not isinstance(x, str):
         for elem in x:
             for result in iter_tensors(elem, only_requiring_grad):
                 yield result
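The str exclusion matters because the new autograd test entries pass string orders ('fro', 'nuc'), and a Python str is itself an Iterable whose elements are again strings: the recursive helpers would descend into 'fro' character by character and never bottom out. A minimal illustration of the pitfall — the walk helper is hypothetical, shaped like make_jacobian/iter_tensors but not part of the commit:

from collections.abc import Iterable

def walk(x):
    # Recurse into containers, but treat strings as leaves.
    if isinstance(x, Iterable) and not isinstance(x, str):
        for elem in x:
            yield from walk(elem)
    else:
        yield x

# Without the str guard, walk('fro') recurses forever: iterating 'f'
# yields 'f' again. With the guard, strings pass through intact:
print(list(walk((1, 'fro', [2, 3]))))  # [1, 'fro', 2, 3]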

torch/functional.py

Lines changed: 80 additions & 0 deletions
@@ -3,6 +3,7 @@
 from torch._six import inf
 from operator import mul
 from functools import reduce
+from collections import Iterable
 import math
 
 __all__ = [
@@ -16,6 +17,7 @@
     'isfinite',
     'isinf',
     'isnan',
+    'norm',
     'meshgrid',
     'split',
     'stft',
@@ -637,3 +639,81 @@ def argsort(input, dim=None, descending=False):
     if dim is None:
         return torch.sort(input, -1, descending)[1]
     return torch.sort(input, dim, descending)[1]
+
+
+def norm(input, p="fro", dim=None, keepdim=False, out=None):
+    r"""Returns the matrix norm or vector norm of a given tensor.
+
+    Args:
+        input (Tensor): the input tensor
+        p ({int, float, inf, -inf, 'fro', 'nuc'}): the order of norm.
+            The following norms can be calculated:
+
+            =====  ============================  ==========================
+            ord    matrix norm                   vector norm
+            =====  ============================  ==========================
+            None   Frobenius norm                2-norm
+            'fro'  Frobenius norm                --
+            'nuc'  nuclear norm                  --
+            Other  as vec norm when dim is None  sum(abs(x)**ord)**(1./ord)
+            =====  ============================  ==========================
+        dim ({int, 2-tuple of ints, 2-list of ints}, optional): If it is an
+            int, vector norm will be calculated; if it is a 2-tuple of ints,
+            matrix norm will be calculated. If the value is None, matrix norm
+            will be calculated when the input tensor has exactly two
+            dimensions, and vector norm when it has exactly one dimension. If
+            the input tensor has more than two dimensions, the vector norm
+            will be applied to the last dimension.
+        keepdim (bool): whether the output tensors have :attr:`dim` retained
+            or not. Ignored if :attr:`dim` = ``None`` and :attr:`out` = ``None``.
+        out (Tensor, optional): the output tensor. Ignored if
+            :attr:`dim` = ``None`` and :attr:`out` = ``None``.
+
+    Example::
+
+        >>> import torch
+        >>> a = torch.arange(9, dtype=torch.float) - 4
+        >>> b = a.reshape((3, 3))
+        >>> torch.norm(a)
+        tensor(7.7460)
+        >>> torch.norm(b)
+        tensor(7.7460)
+        >>> torch.norm(a, float('inf'))
+        tensor(4.)
+        >>> torch.norm(b, float('inf'))
+        tensor([4., 3., 4.])
+        >>> c = torch.tensor([[ 1, 2, 3], [-1, 1, 4]], dtype=torch.float)
+        >>> torch.norm(c, dim=0)
+        tensor([1.4142, 2.2361, 5.0000])
+        >>> torch.norm(c, dim=1)
+        tensor([3.7417, 4.2426])
+        >>> torch.norm(c, p=1, dim=1)
+        tensor([6., 6.])
+        >>> d = torch.arange(8, dtype=torch.float).reshape(2, 2, 2)
+        >>> torch.norm(d, dim=(1, 2))
+        tensor([ 3.7417, 11.2250])
+        >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
+        (tensor(3.7417), tensor(11.2250))
+    """
+    ndim = input.dim()
+
+    # catch default case
+    if dim is None and out is None:
+        if p == "fro":
+            return torch._C._VariableFunctions.frobenius_norm(input)
+        elif p != "nuc":
+            return torch._C._VariableFunctions.norm(input, p)
+
+    if p == "fro":
+        if dim is None:
+            dim = tuple(range(ndim))
+        if out is None:
+            return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim)
+        return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim, out=out)
+    elif p == "nuc":
+        if out is None:
+            return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim)
+        return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim, out=out)
+    else:
+        if out is None:
+            return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim)
+        return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, out=out)
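torch.norm is now a thin Python dispatcher: string orders route to the new frobenius_norm/nuclear_norm bindings, and numeric orders fall through to the existing p-norm. A short sanity-check sketch of the dispatch behavior, assuming a build containing this commit (inputs are illustrative):

import torch

x = torch.randn(3, 3)

# Default p='fro' on a full reduction equals the 2-norm of the flattened tensor.
assert torch.allclose(torch.norm(x), x.view(-1).norm(p=2))

# 'fro' with a dim pair reduces each matrix slice of a batched tensor.
y = torch.randn(2, 3, 4)
assert torch.norm(y, p='fro', dim=(1, 2)).shape == (2,)

# 'nuc' requires a 2-D input and equals the sum of singular values.
assert torch.allclose(torch.norm(x, 'nuc'), torch.svd(x)[1].sum())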

torch/tensor.py

Lines changed: 4 additions & 0 deletions
@@ -241,6 +241,10 @@ def argsort(self, dim=None, descending=False):
         r"""See :func:`torch.argsort`"""
         return torch.argsort(self, dim, descending)
 
+    def norm(self, p="fro", dim=None, keepdim=False):
+        r"""See :func:`torch.norm`"""
+        return torch.norm(self, p, dim, keepdim)
+
     def btrifact(self, info=None, pivot=True):
         r"""See :func:`torch.btrifact`
         """
