Skip to content

Commit 237c27c

Browse files
zou3519 authored and soumith committed
Fix reduction functions not respecting the strides of output when output is correct size (#4995)
1 parent 8056399 commit 237c27c

File tree

9 files changed

+132
-12
lines changed

9 files changed

+132
-12
lines changed

aten/src/TH/generic/THTensorMath.c

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2226,13 +2226,33 @@ ptrdiff_t THTensor_(numel)(THTensor *t)
22262226
return THTensor_(nElement)(t);
22272227
}
22282228

2229+
2230+
// Helper function to be used in a reduction operation.
2231+
// Due to resize semantics of outputs, if the specified output tensor r_ has
2232+
// same size as the output of the reduction operation, then any noncontiguities
2233+
// in r_ should be preserved.
2234+
// The reduction operation, however, needs to act on r_ with an extra dimension
2235+
// (the reduced dimension), so this function "resizes" r_ and preserves its
2236+
// noncontiguities if necessary.
2237+
void THTensor_(preserveReduceDimSemantics)(
2238+
THTensor *r_, int in_dims, int reduce_dimension, int keepdim) {
2239+
if (r_ && !keepdim &&
2240+
THTensor_(nDimension)(r_) == in_dims - 1 &&
2241+
THTensor_(nDimension)(r_) != 0) {
2242+
THTensor_(unsqueeze1d)(r_, r_, reduce_dimension);
2243+
}
2244+
}
2245+
22292246
void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim)
22302247
{
22312248
THLongStorage *dim;
22322249

22332250
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
22342251
dimension + TH_INDEX_BASE);
22352252

2253+
int in_dims = THTensor_(nDimension)(t);
2254+
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
2255+
THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
22362256
dim = THTensor_(newSizeOf)(t);
22372257
THLongStorage_set(dim, dimension, 1);
22382258
THTensor_(resize)(values_, dim, NULL);
@@ -2314,6 +2334,9 @@ void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
23142334
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
23152335
dimension + TH_INDEX_BASE);
23162336

2337+
int in_dims = THTensor_(nDimension)(t);
2338+
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
2339+
THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
23172340
dim = THTensor_(newSizeOf)(t);
23182341
THLongStorage_set(dim, dimension, 1);
23192342
THTensor_(resize)(values_, dim, NULL);
@@ -2395,6 +2418,7 @@ void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension, int keepdim)
23952418
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
23962419
dimension + TH_INDEX_BASE);
23972420

2421+
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
23982422
dim = THTensor_(newSizeOf)(t);
23992423
THLongStorage_set(dim, dimension, 1);
24002424
THTensor_(resize)(r_, dim, NULL);
@@ -2474,6 +2498,7 @@ void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension, int keepdim)
24742498
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
24752499
dimension + TH_INDEX_BASE);
24762500

2501+
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
24772502
dim = THTensor_(newSizeOf)(t);
24782503
THLongStorage_set(dim, dimension, 1);
24792504
THTensor_(resize)(r_, dim, NULL);
@@ -3197,6 +3222,9 @@ void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
31973222

31983223
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "dimension out of range");
31993224

3225+
int in_dims = THTensor_(nDimension)(t);
3226+
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
3227+
THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
32003228
dim = THTensor_(newSizeOf)(t);
32013229
THLongStorage_set(dim, dimension, 1);
32023230
THTensor_(resize)(values_, dim, NULL);
@@ -3263,6 +3291,9 @@ void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t,
32633291
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "dimension out of range");
32643292
THArgCheck(k > 0 && k <= t->size[dimension], 2, "selected index out of range");
32653293

3294+
int in_dims = THTensor_(nDimension)(t);
3295+
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
3296+
THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
32663297
dim = THTensor_(newSizeOf)(t);
32673298
THLongStorage_set(dim, dimension, 1);
32683299
THTensor_(resize)(values_, dim, NULL);
@@ -3778,6 +3809,7 @@ void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int ke
37783809
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "invalid dimension %d",
37793810
dimension + TH_INDEX_BASE);
37803811

3812+
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
37813813
dim = THTensor_(newSizeOf)(t);
37823814
THLongStorage_set(dim, dimension, 1);
37833815
THTensor_(resize)(r_, dim, NULL);
@@ -3821,6 +3853,7 @@ void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int biased, int ke
38213853
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "invalid dimension %d",
38223854
dimension + TH_INDEX_BASE);
38233855

3856+
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
38243857
dim = THTensor_(newSizeOf)(t);
38253858
THLongStorage_set(dim, dimension, 1);
38263859
THTensor_(resize)(r_, dim, NULL);
@@ -3864,6 +3897,7 @@ void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension, int k
38643897
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "invalid dimension %d",
38653898
dimension + TH_INDEX_BASE);
38663899

3900+
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
38673901
dim = THTensor_(newSizeOf)(t);
38683902
THLongStorage_set(dim, dimension, 1);
38693903
THTensor_(resize)(r_, dim, NULL);

aten/src/TH/generic/THTensorMath.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ TH_API void THTensor_(baddbmm)(THTensor *r_, real beta, THTensor *t, real alpha,
7575
TH_API void THTensor_(match)(THTensor *r_, THTensor *m1, THTensor *m2, real gain);
7676

7777
TH_API ptrdiff_t THTensor_(numel)(THTensor *t);
78+
void THTensor_(preserveReduceDimSemantics)(THTensor *r_, int in_dims, int reduce_dimension, int keepdim);
7879
TH_API void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim);
7980
TH_API void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim);
8081
TH_API void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, int64_t k, int dimension, int keepdim);

aten/src/THC/THCReduce.cuh

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,14 @@ bool THC_reduceDim(THCState* state,
325325

326326
}
327327
}
328-
// Resize out to correspond to the reduced size
328+
329+
// Resize out to correspond to the reduced size with keepdim=True.
330+
331+
// Preserve noncontiguities by unsqueezing out if necessary
332+
TensorUtils<TensorType>::preserveReduceDimSemantics(
333+
state, out, TensorUtils<TensorType>::getDims(state, in), dim, keepdim);
334+
335+
// Resize out
329336
THLongStorage* sizes = TensorUtils<TensorType>::newSizeOf(state, in);
330337
THLongStorage_set(sizes, dim, 1);
331338
TensorUtils<TensorType>::resize(state, out, sizes, NULL);

aten/src/THC/THCTensorMathReduce.cuh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "THCNumerics.cuh"
77
#include "THCReduce.cuh"
88
#include "THCReduceAll.cuh"
9+
#include "THCTensorCopy.h"
910
#include "THCThrustAllocator.cuh"
1011
#include <thrust/functional.h>
1112
#include <thrust/device_ptr.h>
@@ -704,6 +705,15 @@ THC_reduceDimIndex(THCState *state,
704705
dimension < TensorUtils<TensorTypeK>::getDims(state, src),
705706
3, "dimension out of range");
706707

708+
709+
// Unsqueeze tgt1_/tgt_2 if necessary so that their contiguity traits
710+
// are preserved if they are the same size as the correct reduction output.
711+
int src_dims = TensorUtils<TensorTypeK>::getDims(state, src);
712+
TensorUtils<TensorTypeK>::preserveReduceDimSemantics(
713+
state, tgt1_, src_dims, dimension, keepdim);
714+
TensorUtils<TensorTypeIndex>::preserveReduceDimSemantics(
715+
state, tgt2_, src_dims, dimension, keepdim);
716+
707717
THLongStorage *dim = TensorUtils<TensorTypeK>::newSizeOf(state, src);
708718
THLongStorage_set(dim, dimension, 1);
709719
TensorUtils<TensorTypeK>::resize(state, tgt1_, dim, NULL);

aten/src/THC/THCTensorTypeUtils.cu

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ TensorUtils<TENSOR_TYPE>::squeeze1d(THCState *state, \
8181
TENSOR_TYPE##_squeeze1d(state, dst, src, dimension); \
8282
} \
8383
\
84+
void \
85+
TensorUtils<TENSOR_TYPE>::unsqueeze1d(THCState *state, \
86+
TENSOR_TYPE *dst, \
87+
TENSOR_TYPE *src, \
88+
int dimension) { \
89+
TENSOR_TYPE##_unsqueeze1d(state, dst, src, dimension); \
90+
} \
91+
\
8492
DATA_TYPE* \
8593
TensorUtils<TENSOR_TYPE>::getData(THCState* state, \
8694
TENSOR_TYPE* t) { \
@@ -133,6 +141,24 @@ TensorUtils<TENSOR_TYPE>::allContiguous(THCState* state, \
133141
return true; \
134142
} \
135143
\
144+
/* Due to the resize semantics of ops with `out=` keywords, if */ \
145+
/* the output `tensor` has the same shape as the output of the */ \
146+
/* reduction operation, then any noncontiguities in the output */ \
147+
/* `tensor` should be preserved. This needs to be special cased b/c */ \
148+
/* otherwise, when keepdim=False, the implementations of reduction */ \
149+
/* ops resize `tensor` to the reduced size with keepdim=True, and */ \
150+
/* then later squeeze `tensor` to the correct output size, breaking */ \
151+
/* the contiguity guarantees of the resize semantics. */ \
152+
void \
153+
TensorUtils<TENSOR_TYPE>::preserveReduceDimSemantics( \
154+
THCState *state, TENSOR_TYPE *tensor, \
155+
int in_dims, int64_t dimension, int keepdim) {\
156+
int out_dims = TensorUtils<TENSOR_TYPE>::getDims(state, tensor); \
157+
if (out_dims > 0 && !keepdim && out_dims == in_dims - 1) { \
158+
TensorUtils<TENSOR_TYPE>::unsqueeze1d(state, tensor, tensor, dimension);\
159+
} \
160+
} \
161+
\
136162
int \
137163
TensorUtils<TENSOR_TYPE>::getDevice(THCState* state, \
138164
TENSOR_TYPE* t) { \

aten/src/THC/THCTensorTypeUtils.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ struct TensorUtils {
5151
TENSOR_TYPE* src); \
5252
static void squeeze1d(THCState *state, TENSOR_TYPE *dst, \
5353
TENSOR_TYPE *src, int dimension); \
54+
static void unsqueeze1d(THCState *state, TENSOR_TYPE *dst, \
55+
TENSOR_TYPE *src, int dimension); \
56+
static void preserveReduceDimSemantics( \
57+
THCState *state, TENSOR_TYPE *tensor, \
58+
int in_dims, int64_t dimension, int keepdim); \
5459
static DATA_TYPE* getData(THCState* state, TENSOR_TYPE* t); \
5560
static ptrdiff_t getNumElements(THCState* state, TENSOR_TYPE* t); \
5661
static int64_t getSize(THCState* state, TENSOR_TYPE* t, int dim); \

aten/src/THC/generic/THCTensorMathReduce.cu

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ THC_API void
7777
THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, int dimension, int biased, int keepdim)
7878
{
7979
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
80+
81+
TensorUtils<THCTensor>::preserveReduceDimSemantics(
82+
state, self_, THCTensor_(nDimension)(state, src), dimension, keepdim);
8083
THLongStorage *dim = THCTensor_(newSizeOf)(state, src);
8184
THLongStorage_set(dim, dimension, 1);
8285
THCTensor_(resize)(state, self_, dim, NULL);
@@ -103,6 +106,9 @@ THC_API void
103106
THCTensor_(var)(THCState *state, THCTensor *self_, THCTensor *src, int dimension, int biased, int keepdim)
104107
{
105108
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
109+
110+
TensorUtils<THCTensor>::preserveReduceDimSemantics(
111+
state, self_, THCTensor_(nDimension)(state, src), dimension, keepdim);
106112
THLongStorage *dim = THCTensor_(newSizeOf)(state, src);
107113
THLongStorage_set(dim, dimension, 1);
108114
THCTensor_(resize)(state, self_, dim, NULL);
@@ -383,17 +389,22 @@ THCTensor_(median)(THCState *state,
383389

384390
THCTensor_(sort)(state, sorted, sorted_indices, self, dimension, 0);
385391

386-
THCTensor_(narrow)(state, values, sorted, dimension, k, 1);
387-
THCudaLongTensor_narrow(state, indices, sorted_indices, dimension, k, 1);
388-
389-
THCTensor_(free)(state, sorted);
390-
THCudaLongTensor_free(state, sorted_indices);
392+
THCTensor *newValues = THCTensor_(newNarrow)(state, sorted, dimension, k, 1);
393+
THCudaLongTensor *newIndices = THCudaLongTensor_newNarrow(state, sorted_indices, dimension, k, 1);
391394

392395
if (!keepdim) {
393-
THCTensor_(squeeze1d)(state, values, values, dimension);
394-
THCudaLongTensor_squeeze1d(state, indices, indices, dimension);
396+
THCTensor_(squeeze1d)(state, newValues, newValues, dimension);
397+
THCudaLongTensor_squeeze1d(state, newIndices, newIndices, dimension);
395398
}
396399

400+
THCTensor_(resizeAs)(state, values, newValues);
401+
THCudaLongTensor_resizeAs(state, indices, newIndices);
402+
THCTensor_(copy)(state, values, newValues);
403+
THCudaLongTensor_copy(state, indices, newIndices);
404+
405+
THCTensor_(free)(state, newValues);
406+
THCudaLongTensor_free(state, newIndices);
407+
397408
THCudaCheck(cudaGetLastError());
398409
}
399410

aten/src/THC/generic/THCTensorMode.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,10 @@ THC_API void THCTensor_(mode)(THCState *state,
180180

181181
// Resize output value, index Tensors to appropriate sizes (i.e. the same as
182182
// the input Tensor, except at dim=dimension, the size is 1)
183+
TensorUtils<THCTensor>::preserveReduceDimSemantics(
184+
state, values, ndim, dimension, keepdim);
185+
TensorUtils<THCudaLongTensor>::preserveReduceDimSemantics(
186+
state, indices, ndim, dimension, keepdim);
183187
dim = THCTensor_(newSizeOf)(state, input);
184188
THLongStorage_set(dim, dimension, 1);
185189
THCTensor_(resize)(state, values, dim, NULL);

test/test_torch.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -383,17 +383,20 @@ def _test_dim_reduction(self, cast):
383383
"mean", "median", "mode", "norm", "prod",
384384
"std", "sum", "var", "max", "min"]
385385

386-
def normfn_attr(t, dim, keepdim=False):
386+
def normfn_attr(t, dim, keepdim=False, out=None):
387387
attr = getattr(torch, "norm")
388-
return attr(t, 2, dim, keepdim)
388+
return attr(t, 2, dim, keepdim, out=out)
389389

390390
for fn_name in dim_red_fns:
391391
fn_attr = getattr(torch, fn_name) if fn_name != "norm" else normfn_attr
392392

393-
def fn(x, dim, keepdim=False):
394-
ans = fn_attr(x, dim, keepdim=keepdim)
393+
def fn(x, dim, keepdim=False, out=None):
394+
ans = fn_attr(x, dim, keepdim=keepdim, out=out)
395395
return ans if not isinstance(ans, tuple) else ans[0]
396396

397+
def fn_tuple(x, dim, keepdim=False, out=None):
398+
return fn_attr(x, dim, keepdim=keepdim, out=out)
399+
397400
def test_multidim(x, dim):
398401
self.assertEqual(fn(x, dim).unsqueeze(dim), fn(x, dim, keepdim=True))
399402
self.assertEqual(x.ndimension() - 1, fn(x, dim).ndimension())
@@ -418,6 +421,25 @@ def test_multidim(x, dim):
418421
x = cast(torch.randn(dims))
419422
test_multidim(x, singleton_dim)
420423

424+
# check reducing with output kwargs
425+
if fn_name in ['median', 'mode', 'max', 'min']:
426+
y = cast(torch.randn(5, 3))
427+
values = cast(torch.randn(5, 3))
428+
indices = cast(torch.zeros(5, 3).long() - 1)
429+
fn_tuple(y, 1, keepdim=False, out=(values[:, 1], indices[:, 1]))
430+
values_expected, indices_expected = fn_tuple(y, 1, keepdim=False)
431+
self.assertEqual(values[:, 1], values_expected,
432+
'{} values with out= kwarg'.format(fn_name))
433+
self.assertEqual(indices[:, 1], indices_expected,
434+
'{} indices with out= kwarg'.format(fn_name))
435+
continue
436+
437+
x = cast(torch.randn(5, 3))
438+
y = cast(torch.randn(5, 3))
439+
fn(y, 1, keepdim=False, out=x[:, 1])
440+
expected = fn(y, 1, keepdim=False)
441+
self.assertEqual(x[:, 1], expected, '{} with out= kwarg'.format(fn_name))
442+
421443
def test_dim_reduction(self):
422444
self._test_dim_reduction(self, lambda t: t)
423445

0 commit comments

Comments (0)