Skip to content

Commit 404510e

Browse files
zou3519 authored and soumith committed
Fix reduction functions not respecting the strides of output when output is correct size (#4995)
1 parent da3c4cb commit 404510e

File tree

9 files changed

+132
-12
lines changed

9 files changed

+132
-12
lines changed

test/test_torch.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -349,17 +349,20 @@ def _test_dim_reduction(self, cast):
349349
"mean", "median", "mode", "norm", "prod",
350350
"std", "sum", "var", "max", "min"]
351351

352-
def normfn_attr(t, dim, keepdim=False):
352+
def normfn_attr(t, dim, keepdim=False, out=None):
353353
attr = getattr(torch, "norm")
354-
return attr(t, 2, dim, keepdim)
354+
return attr(t, 2, dim, keepdim, out=out)
355355

356356
for fn_name in dim_red_fns:
357357
fn_attr = getattr(torch, fn_name) if fn_name != "norm" else normfn_attr
358358

359-
def fn(x, dim, keepdim=False):
360-
ans = fn_attr(x, dim, keepdim=keepdim)
359+
def fn(x, dim, keepdim=False, out=None):
360+
ans = fn_attr(x, dim, keepdim=keepdim, out=out)
361361
return ans if not isinstance(ans, tuple) else ans[0]
362362

363+
def fn_tuple(x, dim, keepdim=False, out=None):
364+
return fn_attr(x, dim, keepdim=keepdim, out=out)
365+
363366
def test_multidim(x, dim):
364367
self.assertEqual(fn(x, dim).unsqueeze(dim), fn(x, dim, keepdim=True))
365368
self.assertEqual(x.ndimension() - 1, fn(x, dim).ndimension())
@@ -384,6 +387,25 @@ def test_multidim(x, dim):
384387
x = cast(torch.randn(dims))
385388
test_multidim(x, singleton_dim)
386389

390+
# check reducing with output kwargs
391+
if fn_name in ['median', 'mode', 'max', 'min']:
392+
y = cast(torch.randn(5, 3))
393+
values = cast(torch.randn(5, 3))
394+
indices = cast(torch.zeros(5, 3).long() - 1)
395+
fn_tuple(y, 1, keepdim=False, out=(values[:, 1], indices[:, 1]))
396+
values_expected, indices_expected = fn_tuple(y, 1, keepdim=False)
397+
self.assertEqual(values[:, 1], values_expected,
398+
'{} values with out= kwarg'.format(fn_name))
399+
self.assertEqual(indices[:, 1], indices_expected,
400+
'{} indices with out= kwarg'.format(fn_name))
401+
continue
402+
403+
x = cast(torch.randn(5, 3))
404+
y = cast(torch.randn(5, 3))
405+
fn(y, 1, keepdim=False, out=x[:, 1])
406+
expected = fn(y, 1, keepdim=False)
407+
self.assertEqual(x[:, 1], expected, '{} with out= kwarg'.format(fn_name))
408+
387409
def test_dim_reduction(self):
388410
self._test_dim_reduction(self, lambda t: t)
389411

torch/lib/TH/generic/THTensorMath.c

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1711,13 +1711,33 @@ ptrdiff_t THTensor_(numel)(THTensor *t)
17111711
return THTensor_(nElement)(t);
17121712
}
17131713

1714+
1715+
// Helper function to be used in a reduction operation.
1716+
// Due to resize semantics of outputs, if the specified output tensor r_ has
1717+
the same size as the output of the reduction operation, then any noncontiguities
1718+
// in r_ should be preserved.
1719+
// The reduction operation, however, needs to act on r_ with an extra dimension
1720+
// (the reduced dimension), so this function "resizes" r_ and preserves its
1721+
// noncontiguities if necessary.
1722+
void THTensor_(preserveReduceDimSemantics)(
1723+
THTensor *r_, int in_dims, int reduce_dimension, int keepdim) {
1724+
if (r_ && !keepdim &&
1725+
THTensor_(nDimension)(r_) == in_dims - 1 &&
1726+
THTensor_(nDimension)(r_) != 0) {
1727+
THTensor_(unsqueeze1d)(r_, r_, reduce_dimension);
1728+
}
1729+
}
1730+
17141731
void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim)
17151732
{
17161733
THLongStorage *dim;
17171734

17181735
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
17191736
dimension + TH_INDEX_BASE);
17201737

1738+
int in_dims = THTensor_(nDimension)(t);
1739+
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
1740+
THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
17211741
dim = THTensor_(newSizeOf)(t);
17221742
THLongStorage_set(dim, dimension, 1);
17231743
THTensor_(resize)(values_, dim, NULL);
@@ -1799,6 +1819,9 @@ void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
17991819
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
18001820
dimension + TH_INDEX_BASE);
18011821

1822+
int in_dims = THTensor_(nDimension)(t);
1823+
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
1824+
THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
18021825
dim = THTensor_(newSizeOf)(t);
18031826
THLongStorage_set(dim, dimension, 1);
18041827
THTensor_(resize)(values_, dim, NULL);
@@ -1881,6 +1904,7 @@ void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension, int keepdim)
18811904
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
18821905
dimension + TH_INDEX_BASE);
18831906

1907+
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
18841908
dim = THTensor_(newSizeOf)(t);
18851909
THLongStorage_set(dim, dimension, 1);
18861910
THTensor_(resize)(r_, dim, NULL);
@@ -1917,6 +1941,7 @@ void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension, int keepdim)
19171941
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
19181942
dimension + TH_INDEX_BASE);
19191943

1944+
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
19201945
dim = THTensor_(newSizeOf)(t);
19211946
THLongStorage_set(dim, dimension, 1);
19221947
THTensor_(resize)(r_, dim, NULL);
@@ -2597,6 +2622,9 @@ void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int
25972622

25982623
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "dimension out of range");
25992624

2625+
int in_dims = THTensor_(nDimension)(t);
2626+
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
2627+
THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
26002628
dim = THTensor_(newSizeOf)(t);
26012629
THLongStorage_set(dim, dimension, 1);
26022630
THTensor_(resize)(values_, dim, NULL);
@@ -2663,6 +2691,9 @@ void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t,
26632691
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "dimension out of range");
26642692
THArgCheck(k > 0 && k <= t->size[dimension], 2, "selected index out of range");
26652693

2694+
int in_dims = THTensor_(nDimension)(t);
2695+
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
2696+
THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
26662697
dim = THTensor_(newSizeOf)(t);
26672698
THLongStorage_set(dim, dimension, 1);
26682699
THTensor_(resize)(values_, dim, NULL);
@@ -3151,6 +3182,7 @@ void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int ke
31513182
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "invalid dimension %d",
31523183
dimension + TH_INDEX_BASE);
31533184

3185+
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
31543186
dim = THTensor_(newSizeOf)(t);
31553187
THLongStorage_set(dim, dimension, 1);
31563188
THTensor_(resize)(r_, dim, NULL);
@@ -3194,6 +3226,7 @@ void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int biased, int ke
31943226
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "invalid dimension %d",
31953227
dimension + TH_INDEX_BASE);
31963228

3229+
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
31973230
dim = THTensor_(newSizeOf)(t);
31983231
THLongStorage_set(dim, dimension, 1);
31993232
THTensor_(resize)(r_, dim, NULL);
@@ -3237,6 +3270,7 @@ void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension, int k
32373270
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "invalid dimension %d",
32383271
dimension + TH_INDEX_BASE);
32393272

3273+
THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
32403274
dim = THTensor_(newSizeOf)(t);
32413275
THLongStorage_set(dim, dimension, 1);
32423276
THTensor_(resize)(r_, dim, NULL);

torch/lib/TH/generic/THTensorMath.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ TH_API void THTensor_(baddbmm)(THTensor *r_, real beta, THTensor *t, real alpha,
7575
TH_API void THTensor_(match)(THTensor *r_, THTensor *m1, THTensor *m2, real gain);
7676

7777
TH_API ptrdiff_t THTensor_(numel)(THTensor *t);
78+
void THTensor_(preserveReduceDimSemantics)(THTensor *r_, int in_dims, int reduce_dimension, int keepdim);
7879
TH_API void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim);
7980
TH_API void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim);
8081
TH_API void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, int64_t k, int dimension, int keepdim);

torch/lib/THC/THCReduce.cuh

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,14 @@ bool THC_reduceDim(THCState* state,
325325

326326
}
327327
}
328-
// Resize out to correspond to the reduced size
328+
329+
// Resize out to correspond to the reduced size with keepdim=True.
330+
331+
// Preserve noncontiguities by unsqueezing out if necessary
332+
TensorUtils<TensorType>::preserveReduceDimSemantics(
333+
state, out, TensorUtils<TensorType>::getDims(state, in), dim, keepdim);
334+
335+
// Resize out
329336
THLongStorage* sizes = TensorUtils<TensorType>::newSizeOf(state, in);
330337
THLongStorage_set(sizes, dim, 1);
331338
TensorUtils<TensorType>::resize(state, out, sizes, NULL);

torch/lib/THC/THCTensorMathReduce.cuh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "THCNumerics.cuh"
77
#include "THCReduce.cuh"
88
#include "THCReduceAll.cuh"
9+
#include "THCTensorCopy.h"
910
#include "THCThrustAllocator.cuh"
1011
#include <thrust/functional.h>
1112
#include <thrust/device_ptr.h>
@@ -710,6 +711,15 @@ THC_reduceDimIndex(THCState *state,
710711
dimension < TensorUtils<TensorTypeK>::getDims(state, src),
711712
3, "dimension out of range");
712713

714+
715+
// Unsqueeze tgt1_/tgt2_ if necessary so that their contiguity traits
716+
// are preserved if they are the same size as the correct reduction output.
717+
int src_dims = TensorUtils<TensorTypeK>::getDims(state, src);
718+
TensorUtils<TensorTypeK>::preserveReduceDimSemantics(
719+
state, tgt1_, src_dims, dimension, keepdim);
720+
TensorUtils<TensorTypeIndex>::preserveReduceDimSemantics(
721+
state, tgt2_, src_dims, dimension, keepdim);
722+
713723
THLongStorage *dim = TensorUtils<TensorTypeK>::newSizeOf(state, src);
714724
THLongStorage_set(dim, dimension, 1);
715725
TensorUtils<TensorTypeK>::resize(state, tgt1_, dim, NULL);

torch/lib/THC/THCTensorTypeUtils.cu

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ TensorUtils<TENSOR_TYPE>::squeeze1d(THCState *state, \
8181
TENSOR_TYPE##_squeeze1d(state, dst, src, dimension); \
8282
} \
8383
\
84+
void \
85+
TensorUtils<TENSOR_TYPE>::unsqueeze1d(THCState *state, \
86+
TENSOR_TYPE *dst, \
87+
TENSOR_TYPE *src, \
88+
int dimension) { \
89+
TENSOR_TYPE##_unsqueeze1d(state, dst, src, dimension); \
90+
} \
91+
\
8492
DATA_TYPE* \
8593
TensorUtils<TENSOR_TYPE>::getData(THCState* state, \
8694
TENSOR_TYPE* t) { \
@@ -133,6 +141,24 @@ TensorUtils<TENSOR_TYPE>::allContiguous(THCState* state, \
133141
return true; \
134142
} \
135143
\
144+
/* Due to the resize semantics of ops with `out=` keywords, if */ \
145+
/* the output `tensor` has the same shape as the output of the */ \
146+
/* reduction operation, then any noncontiguities in the output */ \
147+
/* `tensor` should be preserved. This needs to be special cased b/c */ \
148+
/* otherwise, when keepdim=False, the implementations of reduction */ \
149+
/* ops resize `tensor` to the reduced size with keepdim=True, and */ \
150+
/* then later squeeze `tensor` to the correct output size, breaking */ \
151+
/* the contiguity guarantees of the resize semantics. */ \
152+
void \
153+
TensorUtils<TENSOR_TYPE>::preserveReduceDimSemantics( \
154+
THCState *state, TENSOR_TYPE *tensor, \
155+
int in_dims, int64_t dimension, int keepdim) {\
156+
int out_dims = TensorUtils<TENSOR_TYPE>::getDims(state, tensor); \
157+
if (out_dims > 0 && !keepdim && out_dims == in_dims - 1) { \
158+
TensorUtils<TENSOR_TYPE>::unsqueeze1d(state, tensor, tensor, dimension);\
159+
} \
160+
} \
161+
\
136162
int \
137163
TensorUtils<TENSOR_TYPE>::getDevice(THCState* state, \
138164
TENSOR_TYPE* t) { \

torch/lib/THC/THCTensorTypeUtils.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ struct TensorUtils {
5151
TENSOR_TYPE* src); \
5252
static void squeeze1d(THCState *state, TENSOR_TYPE *dst, \
5353
TENSOR_TYPE *src, int dimension); \
54+
static void unsqueeze1d(THCState *state, TENSOR_TYPE *dst, \
55+
TENSOR_TYPE *src, int dimension); \
56+
static void preserveReduceDimSemantics( \
57+
THCState *state, TENSOR_TYPE *tensor, \
58+
int in_dims, int64_t dimension, int keepdim); \
5459
static DATA_TYPE* getData(THCState* state, TENSOR_TYPE* t); \
5560
static ptrdiff_t getNumElements(THCState* state, TENSOR_TYPE* t); \
5661
static int64_t getSize(THCState* state, TENSOR_TYPE* t, int dim); \

torch/lib/THC/generic/THCTensorMathReduce.cu

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ THC_API void
7777
THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, int dimension, int biased, int keepdim)
7878
{
7979
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
80+
81+
TensorUtils<THCTensor>::preserveReduceDimSemantics(
82+
state, self_, THCTensor_(nDimension)(state, src), dimension, keepdim);
8083
THLongStorage *dim = THCTensor_(newSizeOf)(state, src);
8184
THLongStorage_set(dim, dimension, 1);
8285
THCTensor_(resize)(state, self_, dim, NULL);
@@ -103,6 +106,9 @@ THC_API void
103106
THCTensor_(var)(THCState *state, THCTensor *self_, THCTensor *src, int dimension, int biased, int keepdim)
104107
{
105108
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));
109+
110+
TensorUtils<THCTensor>::preserveReduceDimSemantics(
111+
state, self_, THCTensor_(nDimension)(state, src), dimension, keepdim);
106112
THLongStorage *dim = THCTensor_(newSizeOf)(state, src);
107113
THLongStorage_set(dim, dimension, 1);
108114
THCTensor_(resize)(state, self_, dim, NULL);
@@ -383,17 +389,22 @@ THCTensor_(median)(THCState *state,
383389

384390
THCTensor_(sort)(state, sorted, sorted_indices, self, dimension, 0);
385391

386-
THCTensor_(narrow)(state, values, sorted, dimension, k, 1);
387-
THCudaLongTensor_narrow(state, indices, sorted_indices, dimension, k, 1);
388-
389-
THCTensor_(free)(state, sorted);
390-
THCudaLongTensor_free(state, sorted_indices);
392+
THCTensor *newValues = THCTensor_(newNarrow)(state, sorted, dimension, k, 1);
393+
THCudaLongTensor *newIndices = THCudaLongTensor_newNarrow(state, sorted_indices, dimension, k, 1);
391394

392395
if (!keepdim) {
393-
THCTensor_(squeeze1d)(state, values, values, dimension);
394-
THCudaLongTensor_squeeze1d(state, indices, indices, dimension);
396+
THCTensor_(squeeze1d)(state, newValues, newValues, dimension);
397+
THCudaLongTensor_squeeze1d(state, newIndices, newIndices, dimension);
395398
}
396399

400+
THCTensor_(resizeAs)(state, values, newValues);
401+
THCudaLongTensor_resizeAs(state, indices, newIndices);
402+
THCTensor_(copy)(state, values, newValues);
403+
THCudaLongTensor_copy(state, indices, newIndices);
404+
405+
THCTensor_(free)(state, newValues);
406+
THCudaLongTensor_free(state, newIndices);
407+
397408
THCudaCheck(cudaGetLastError());
398409
}
399410

torch/lib/THC/generic/THCTensorMode.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,10 @@ THC_API void THCTensor_(mode)(THCState *state,
180180

181181
// Resize output value, index Tensors to appropriate sizes (i.e. the same as
182182
// the input Tensor, except at dim=dimension, the size is 1)
183+
TensorUtils<THCTensor>::preserveReduceDimSemantics(
184+
state, values, ndim, dimension, keepdim);
185+
TensorUtils<THCudaLongTensor>::preserveReduceDimSemantics(
186+
state, indices, ndim, dimension, keepdim);
183187
dim = THCTensor_(newSizeOf)(state, input);
184188
THLongStorage_set(dim, dimension, 1);
185189
THCTensor_(resize)(state, values, dim, NULL);

0 commit comments

Comments
 (0)