
Commit f98b778

t-vi authored and apaszke committed
Fix forward and backward for norm/renorm with infty norm (fixes #6817) (#6969)
1 parent 24d0566 commit f98b778

File tree

aten/src/TH/generic/THTensorMath.c
aten/src/THC/THCTensorMathReduce.cuh
aten/src/THC/generic/THCTensorMathReduce.cu
test/test_autograd.py
test/test_torch.py
tools/autograd/templates/Functions.cpp

6 files changed: +134 −19 lines changed


aten/src/TH/generic/THTensorMath.c

Lines changed: 11 additions & 1 deletion
@@ -4252,6 +4252,9 @@ void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension, int k
   } else if (value == 3) {
     DIM_REDUCE(sum += TH_MATH_NAME(fabs)(t_data[i*t_stride] * t_data[i*t_stride] * t_data[i*t_stride]),
                *r__data = TH_MATH_NAME(pow)(sum, 1.0/3));
+  } else if (value == INFINITY) {
+    DIM_REDUCE(sum = THMax(sum, TH_MATH_NAME(fabs)(t_data[i*t_stride])),
+               *r__data = sum);
   } else {
     DIM_REDUCE(sum += TH_MATH_NAME(pow)(TH_MATH_NAME(fabs)(t_data[i*t_stride]), value),
                *r__data = TH_MATH_NAME(pow)(sum, 1.0/value));
@@ -4278,6 +4281,9 @@ accreal THTensor_(normall)(THTensor *tensor, real value)
   } else if(value == 3) {
     TH_TENSOR_APPLY(real, tensor, accreal z = *tensor_data; sum += std::abs(z*z*z););
     return TH_MATH_NAME(pow)(sum, 1.0/3);
+  } else if(value == INFINITY) {
+    TH_TENSOR_APPLY(real, tensor, sum = THMax(sum, TH_MATH_NAME(fabs)(*tensor_data)););
+    return sum;
   } else {
     TH_TENSOR_APPLY(real, tensor, sum += TH_MATH_NAME(pow)(TH_MATH_NAME(fabs)(*tensor_data), value););
     return TH_MATH_NAME(pow)(sum, 1.0/value);
@@ -4311,11 +4317,15 @@ void THTensor_(renorm)(THTensor *res, THTensor *src, real value, int dimension,
       TH_TENSOR_APPLY(real, rowS, norm += fabs(*rowS_data););
     } else if (value == 2) {
       TH_TENSOR_APPLY(real, rowS, accreal z = *rowS_data; norm += z*z;);
+    } else if (value == INFINITY) {
+      TH_TENSOR_APPLY(real, rowS, norm = THMax(norm, TH_MATH_NAME(fabs)(*rowS_data)););
     } else {
       TH_TENSOR_APPLY(real, rowS, norm += TH_MATH_NAME(pow)(TH_MATH_NAME(fabs)(*rowS_data), value););
     }
 
-    norm = pow(norm, 1/value);
+    if (value != INFINITY) {
+      norm = pow(norm, 1/value);
+    }
 
     if (norm > maxnorm)
     {

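For intuition: the new branches compute the inf-norm as a running maximum of absolute values, and since max has no final 1/p root, the renorm path now skips pow(norm, 1/value). A minimal NumPy sketch of the intended semantics (illustrative only, not part of the patch):

import numpy as np

x = np.array([1.0, -7.0, 3.0])
inf_norm = np.abs(x).max()                    # max |x_i|; no 1/p root applied
assert inf_norm == np.linalg.norm(x, np.inf)  # matches NumPy's vector inf-norm
# limiting behaviour: (sum |x_i|^p)^(1/p) -> max |x_i| as p grows
assert abs((np.abs(x) ** 100).sum() ** (1.0 / 100) - inf_norm) < 1e-6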
aten/src/THC/THCTensorMathReduce.cuh

Lines changed: 63 additions & 18 deletions
@@ -116,6 +116,23 @@ struct ReduceMax {
   }
 };
 
+template <typename InT, typename AccT>
+struct ReduceMaxTo {
+  inline __device__ AccT operator()(InT a, InT b) const {
+    return ScalarConvert<InT, AccT>::to(THCNumerics<InT>::gt(a, b) ? a : b);
+  }
+};
+
+#ifdef CUDA_HALF_TENSOR
+template <>
+struct ReduceMaxTo<half, float> {
+  inline __device__ float operator()(float a, half b) const {
+    float b_f = __half2float(b);
+    return (THCNumerics<float>::gt(a, b_f) ? a : b_f);
+  }
+};
+#endif // CUDA_HALF_TENSOR
+
 struct LogicalAll {
   inline __device__ unsigned char operator()(unsigned char x,
                                              unsigned char y) const {
@@ -130,6 +147,11 @@ struct LogicalAny {
   }
 };
 
+template<typename Real>
+inline __device__ Real THCMax(const Real a, const Real b) {
+  return THCNumerics<Real>::gt(a, b) ? a : b;
+}
+
 template<typename Real>
 __global__ void THCTensor_kernel_renorm(Real *data, const Real value, const ptrdiff_t size, const Real maxnorm)
 {
@@ -140,27 +162,50 @@ __global__ void THCTensor_kernel_renorm(Real *data, const Real value, const ptrd
   Real *row = data + size*bx;
 
   buffer[tx] = ScalarConvert<int, Real>::to(0);
+  Real norm;
 
-  // get norm of axis
-  for (ptrdiff_t i=tx; i<size; i+=step)
-  {
-    buffer[tx] = THCNumerics<Real>::add(
-      buffer[tx],
-      THCNumerics<Real>::pow(
-        THCNumerics<Real>::abs(row[i]),
-        value)
-    );
-  }
-  // add (reduce)
-  for (unsigned int stride = blockDim.x >> 1; stride > 0; stride >>= 1)
-  {
+  if (THCNumerics<Real>::eq(value, ScalarConvert<float, Real>::to(INFINITY))) {
+    // get norm of axis
+    for (ptrdiff_t i=tx; i<size; i+=step)
+    {
+      buffer[tx] = THCMax<Real>(
+        buffer[tx],
+        THCNumerics<Real>::abs(row[i])
+      );
+    }
+    // add (reduce)
+    for (unsigned int stride = blockDim.x >> 1; stride > 0; stride >>= 1)
+    {
+      __syncthreads();
+      if (tx < stride)
+        buffer[tx] = THCMax<Real>(buffer[tx], buffer[tx+stride]);
+    }
+    // clip norms
     __syncthreads();
-    if (tx < stride)
-      buffer[tx] = THCNumerics<Real>::add(buffer[tx], buffer[tx+stride]);
+    norm = buffer[0];
+  } else {
+    // get norm of axis
+    for (ptrdiff_t i=tx; i<size; i+=step)
+    {
+      buffer[tx] = THCNumerics<Real>::add(
+        buffer[tx],
+        THCNumerics<Real>::pow(
+          THCNumerics<Real>::abs(row[i]),
+          value)
+      );
+    }
+    // add (reduce)
+    for (unsigned int stride = blockDim.x >> 1; stride > 0; stride >>= 1)
+    {
+      __syncthreads();
+      if (tx < stride)
+        buffer[tx] = THCNumerics<Real>::add(buffer[tx], buffer[tx+stride]);
+    }
+    // clip norms
+    __syncthreads();
+    norm = THCNumerics<Real>::pow(buffer[0], THCNumerics<Real>::cinv(value));
   }
-  // clip norms
-  __syncthreads();
-  Real norm = THCNumerics<Real>::pow(buffer[0], THCNumerics<Real>::cinv(value));
+
   if (THCNumerics<Real>::gt(norm, maxnorm))
   {
     norm = THCNumerics<Real>::div(

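The kernel now special-cases INFINITY: the same stride-halving shared-memory reduction runs with max in place of add, and the final pow(..., cinv(value)) is dropped. A single-threaded Python sketch of that tree reduction (illustrative; in the kernel each tx of the inner loop is a parallel thread, separated by __syncthreads()):

def tree_max(buf):
    # stride-halving reduction: buf[0] ends up holding the max of all
    # entries; len(buf) is a power of two, as blockDim.x is in the kernel
    stride = len(buf) // 2
    while stride > 0:
        for tx in range(stride):
            buf[tx] = max(buf[tx], buf[tx + stride])
        stride //= 2
    return buf[0]

assert tree_max([3.0, 9.0, 1.0, 4.0]) == 9.0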
aten/src/THC/generic/THCTensorMathReduce.cu

Lines changed: 11 additions & 0 deletions
@@ -182,6 +182,10 @@ THCTensor_(norm)(THCState *state, THCTensor* self, THCTensor* src, real value, i
                   ScalarConvert<float, accreal>::to(0.0), dimension, keepdim);
     THCTensor_(pow)(state, self, self, ScalarConvert<float, real>::to(0.5));
 
+  } else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(INFINITY))) {
+    THC_reduceDim(state, self, src,
+                  TensorNormOp<real, 1>(value), ReduceMaxTo<real, accreal>(), ReduceMax<accreal>(),
+                  ScalarConvert<float, accreal>::to(0.0), dimension, keepdim);
   } else {
     THC_reduceDim(state, self, src,
                   TensorNormOp<real, -1>(value), ReduceAdd<real, accreal>(), ReduceAdd<accreal, accreal>(),
@@ -220,6 +224,13 @@ THCTensor_(normall)(THCState *state, THCTensor *self, real value)
                   ScalarConvert<float, accreal>::to(0.0f),
                   &result, 0);
     result = THCNumerics<accreal>::sqrt(result);
+  } else if (THCNumerics<real>::eq(value, ScalarConvert<float, real>::to(INFINITY))) {
+    THC_reduceAll(state, self,
+                  TensorNormOp<real, 1>(value),
+                  ReduceMaxTo<real, accreal>(),
+                  ReduceMax<accreal>(),
+                  ScalarConvert<float, accreal>::to(0.0f),
+                  &result, 0);
   } else {
     THC_reduceAll(state, self,
                   TensorNormOp<real, -1>(value),

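With these dispatch branches, p = inf reaches the max-based reductions on the GPU as well. A short usage sketch (assumes a CUDA device is available):

import torch

x = torch.randn(5, 5).cuda()
full = x.norm(float('inf'))     # max |x_ij| over the whole tensor
rows = x.norm(float('inf'), 1)  # max |x_ij| along dim 1, one value per row
assert full.item() == x.abs().max().item()
assert (rows == x.abs().max(1)[0]).all()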
test/test_autograd.py

Lines changed: 2 additions & 0 deletions
@@ -2531,6 +2531,7 @@ class dont_convert(tuple):
     ('std', (S,), (0, True, True), 'keepdim_dim_1d', [0]),
     ('renorm', (S, S, S), (2, 1, 0.5), 'dim', [1]),
     ('renorm', (S, S, S), (1, 2, 3), 'norm_1'),
+    ('renorm', (S, S, S), (float('inf'), 2, 0.5), 'norm_inf'),
     ('repeat', (S,), (2,), 'single_number'),
     ('repeat', (), (2, 3), 'scalar'),
     ('repeat', (2, 2), (3, 2)),
@@ -2619,6 +2620,7 @@ class dont_convert(tuple):
     ('norm', (S, S), (0.5,), '0_5'),
     ('norm', (S, S), (1,), '1'),
     ('norm', (S, S), (3,), '3'),
+    ('norm', (S, S), (float('inf'),), 'inf'),
     ('norm', (S, S), (-1,), 'neg_1'),
     ('norm', (S, S), (-0.5,), 'neg_0_5'),
     ('norm', (S, S), (-1.5,), 'neg_1_5'),

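These table entries route the new inf cases through the autograd test harness. An equivalent standalone check, sketched here with the public gradcheck API (double precision, since gradcheck compares against finite differences):

import torch
from torch.autograd import gradcheck

x = torch.randn(5, 5, dtype=torch.float64, requires_grad=True)
assert gradcheck(lambda t: t.norm(float('inf')), (x,))

Ties in |x| would make the inf-norm non-differentiable, but a random input almost surely has a unique maximum, so the check is well posed.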
test/test_torch.py

Lines changed: 44 additions & 0 deletions
@@ -555,6 +555,33 @@ def test_max(self):
     def test_min(self):
         self._testSelection(torch.min, min)
 
+    @staticmethod
+    def _test_norm(self, device):
+        # full reduction
+        x = torch.randn(5, device=device)
+        xn = x.cpu().numpy()
+        for p in [0, 1, 2, 3, 4, float('inf')]:
+            res = x.norm(p).item()
+            expected = np.linalg.norm(xn, p)
+            self.assertEqual(res, expected, "full reduction failed for {}-norm".format(p))
+        # one dimension
+        x = torch.randn(5, 5, device=device)
+        xn = x.cpu().numpy()
+        for p in [0, 1, 2, 3, 4, float('inf')]:
+            res = x.norm(p, 1).cpu().numpy()
+            expected = np.linalg.norm(xn, p, 1)
+            self.assertEqual(res.shape, expected.shape)
+            self.assertTrue(np.allclose(res, expected), "dim reduction failed for {}-norm".format(p))
+
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    def test_norm(self):
+        self._test_norm(self, device='cpu')
+
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
+    def test_norm_cuda(self):
+        self._test_norm(self, device='cuda')
+
     def test_dim_reduction_uint8_overflow(self):
         example = [[-1, 2, 1], [5, 3, 6]]
         x = torch.tensor(example, dtype=torch.uint8)
@@ -2056,6 +2083,23 @@ def renorm(matrix, value, dim, max_norm):
         self.assertEqual(m3, m2)
         self.assertEqual(m3.norm(2, 0), m2.norm(2, 0))
 
+    @staticmethod
+    def _test_renorm_ps(self, device):
+        # full reduction
+        x = torch.randn(5, 5)
+        xn = x.numpy()
+        for p in [1, 2, 3, 4, float('inf')]:
+            res = x.renorm(p, 1, 1)
+            expected = x / x.norm(p, 0, keepdim=True).clamp(min=1)
+            self.assertEqual(res.numpy(), expected.numpy(), "renorm failed for {}-norm".format(p))
+
+    def test_renorm_ps(self):
+        self._test_renorm_ps(self, device='cpu')
+
+    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
+    def test_renorm_ps_cuda(self):
+        self._test_renorm_ps(self, device='cuda')
+
     @staticmethod
     def _test_multinomial(self, type):
         def make_prob_dist(shape, is_contiguous):

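The expected value in _test_renorm_ps follows from renorm's contract: each sub-tensor along the given dim whose p-norm exceeds maxnorm = 1 is rescaled to norm 1, i.e. divided by max(norm, 1), which is exactly x / x.norm(p, 0, keepdim=True).clamp(min=1). A small worked example (output approximate, since the kernel divides by the norm plus a tiny epsilon):

import torch

x = torch.tensor([[3.0, 0.2],
                  [-4.0, 0.1]])
# column 0 has inf-norm 4 > 1, so it is scaled by 1/4;
# column 1 has inf-norm 0.2 <= 1, so it is left unchanged
out = x.renorm(float('inf'), 1, 1)
print(out)  # approximately [[0.75, 0.2], [-1.0, 0.1]]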
tools/autograd/templates/Functions.cpp

Lines changed: 3 additions & 0 deletions
@@ -86,6 +86,9 @@ Tensor norm_backward(const Tensor & grad, const Tensor & self, const Scalar & p_
   } else if (p == 2.0) {
     self_scaled = self;
     scale_v = grad / norm;
+  } else if (p == INFINITY) {
+    self_scaled = self.sign() * (self.abs() == norm).toType(self.type());
+    scale_v = grad.clone();
   } else {
     self_scaled = self * self.abs().pow(p - 2);
     scale_v = grad / norm.pow(p - 1);

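The new branch implements the subgradient of max_i |x_i|: the gradient is sign(x_i) at the entries attaining the maximum and zero elsewhere, and the incoming grad passes through unscaled (hence grad.clone()). A quick numeric check (illustrative):

import torch

x = torch.tensor([1.0, -7.0, 3.0], requires_grad=True)
x.norm(float('inf')).backward()
print(x.grad)  # tensor([ 0., -1.,  0.]): sign at the argmax, zero elsewhere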