34 changes: 34 additions & 0 deletions aten/src/TH/generic/THTensorMath.c
@@ -2226,13 +2226,33 @@ ptrdiff_t THTensor_(numel)(THTensor *t)
return THTensor_(nElement)(t);
}


// Helper function to be used in a reduction operation.
// Due to the resize semantics of outputs, if the specified output tensor r_ has
// the same size as the output of the reduction operation, then any noncontiguities
// in r_ should be preserved.
// The reduction operation, however, needs to act on r_ with an extra dimension
// (the reduced dimension), so this function "resizes" r_ and preserves its
// noncontiguities if necessary.
void THTensor_(preserveReduceDimSemantics)(
THTensor *r_, int in_dims, int reduce_dimension, int keepdim) {
if (r_ && !keepdim &&
THTensor_(nDimension)(r_) == in_dims - 1 &&
THTensor_(nDimension)(r_) != 0) {
THTensor_(unsqueeze1d)(r_, r_, reduce_dimension);
}
}
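
To make the preserved semantics concrete, here is a minimal Python sketch (hypothetical tensors; assumes a build that includes this patch). Reducing into a noncontiguous out= view of matching size should write through the view rather than reallocate it:

import torch

x = torch.randn(5, 3)
base = torch.randn(5, 3)
col = base[:, 1]            # noncontiguous column view of base, stride (3,)
torch.sum(x, 1, out=col)    # col already has the reduced shape, so its strides survive
# the row sums are visible through the original tensor:
assert torch.equal(base[:, 1], torch.sum(x, 1))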

void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim)
{
THLongStorage *dim;

THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
dimension + TH_INDEX_BASE);

int in_dims = THTensor_(nDimension)(t);
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
THTensor_(resize)(values_, dim, NULL);
@@ -2314,6 +2334,9 @@ void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim)
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
dimension + TH_INDEX_BASE);

int in_dims = THTensor_(nDimension)(t);
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
THTensor_(resize)(values_, dim, NULL);
@@ -2395,6 +2418,7 @@ void THTensor_(sum)(THTensor *r_, THTensor *t, int dimension, int keepdim)
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
dimension + TH_INDEX_BASE);

THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
THTensor_(resize)(r_, dim, NULL);
@@ -2474,6 +2498,7 @@ void THTensor_(prod)(THTensor *r_, THTensor *t, int dimension, int keepdim)
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 2, "dimension %d out of range",
dimension + TH_INDEX_BASE);

THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
THTensor_(resize)(r_, dim, NULL);
@@ -3197,6 +3222,9 @@ void THTensor_(mode)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim)

THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "dimension out of range");

int in_dims = THTensor_(nDimension)(t);
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
THTensor_(resize)(values_, dim, NULL);
@@ -3263,6 +3291,9 @@ void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, int64_t k, int dimension, int keepdim)
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "dimension out of range");
THArgCheck(k > 0 && k <= t->size[dimension], 2, "selected index out of range");

int in_dims = THTensor_(nDimension)(t);
THTensor_(preserveReduceDimSemantics)(values_, in_dims, dimension, keepdim);
THLongTensor_preserveReduceDimSemantics(indices_, in_dims, dimension, keepdim);
dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
THTensor_(resize)(values_, dim, NULL);
@@ -3778,6 +3809,7 @@ void THTensor_(std)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim)
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "invalid dimension %d",
dimension + TH_INDEX_BASE);

THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
THTensor_(resize)(r_, dim, NULL);
@@ -3821,6 +3853,7 @@ void THTensor_(var)(THTensor *r_, THTensor *t, int dimension, int biased, int keepdim)
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "invalid dimension %d",
dimension + TH_INDEX_BASE);

THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
THTensor_(resize)(r_, dim, NULL);
@@ -3864,6 +3897,7 @@ void THTensor_(norm)(THTensor *r_, THTensor *t, real value, int dimension, int keepdim)
THArgCheck(dimension >= 0 && dimension < THTensor_(nDimension)(t), 3, "invalid dimension %d",
dimension + TH_INDEX_BASE);

THTensor_(preserveReduceDimSemantics)(r_, THTensor_(nDimension)(t), dimension, keepdim);
dim = THTensor_(newSizeOf)(t);
THLongStorage_set(dim, dimension, 1);
THTensor_(resize)(r_, dim, NULL);
1 change: 1 addition & 0 deletions aten/src/TH/generic/THTensorMath.h
@@ -75,6 +75,7 @@ TH_API void THTensor_(baddbmm)(THTensor *r_, real beta, THTensor *t, real alpha,
TH_API void THTensor_(match)(THTensor *r_, THTensor *m1, THTensor *m2, real gain);

TH_API ptrdiff_t THTensor_(numel)(THTensor *t);
void THTensor_(preserveReduceDimSemantics)(THTensor *r_, int in_dims, int reduce_dimension, int keepdim);
TH_API void THTensor_(max)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim);
TH_API void THTensor_(min)(THTensor *values_, THLongTensor *indices_, THTensor *t, int dimension, int keepdim);
TH_API void THTensor_(kthvalue)(THTensor *values_, THLongTensor *indices_, THTensor *t, int64_t k, int dimension, int keepdim);
9 changes: 8 additions & 1 deletion aten/src/THC/THCReduce.cuh
@@ -325,7 +325,14 @@ bool THC_reduceDim(THCState* state,

}
}
// Resize out to correspond to the reduced size

// Resize out to correspond to the reduced size with keepdim=True.

// Preserve noncontiguities by unsqueezing out if necessary
TensorUtils<TensorType>::preserveReduceDimSemantics(
state, out, TensorUtils<TensorType>::getDims(state, in), dim, keepdim);

// Resize out
THLongStorage* sizes = TensorUtils<TensorType>::newSizeOf(state, in);
THLongStorage_set(sizes, dim, 1);
TensorUtils<TensorType>::resize(state, out, sizes, NULL);
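
In rough Python terms, the unsqueeze-then-resize dance above does something like this (a sketch with hypothetical shapes, not the actual CUDA path):

import torch

out = torch.randn(5, 3)[:, 1]  # user-supplied out: shape (5,), stride (3,), noncontiguous
out.unsqueeze_(1)              # shape (5, 1), strides (3, 1): same storage, keepdim layout
# the resize to (5, 1) is now a no-op, so the kernel writes through the view;
# with keepdim=False a final squeeze restores shape (5,) without reallocating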
10 changes: 10 additions & 0 deletions aten/src/THC/THCTensorMathReduce.cuh
@@ -6,6 +6,7 @@
#include "THCNumerics.cuh"
#include "THCReduce.cuh"
#include "THCReduceAll.cuh"
#include "THCTensorCopy.h"
#include "THCThrustAllocator.cuh"
#include <thrust/functional.h>
#include <thrust/device_ptr.h>
@@ -704,6 +705,15 @@ THC_reduceDimIndex(THCState *state,
dimension < TensorUtils<TensorTypeK>::getDims(state, src),
3, "dimension out of range");


// Unsqueeze tgt1_/tgt2_ if necessary so that their contiguity traits
// are preserved if they are the same size as the correct reduction output.
int src_dims = TensorUtils<TensorTypeK>::getDims(state, src);
TensorUtils<TensorTypeK>::preserveReduceDimSemantics(
state, tgt1_, src_dims, dimension, keepdim);
TensorUtils<TensorTypeIndex>::preserveReduceDimSemantics(
state, tgt2_, src_dims, dimension, keepdim);

THLongStorage *dim = TensorUtils<TensorTypeK>::newSizeOf(state, src);
THLongStorage_set(dim, dimension, 1);
TensorUtils<TensorTypeK>::resize(state, tgt1_, dim, NULL);
26 changes: 26 additions & 0 deletions aten/src/THC/THCTensorTypeUtils.cu
@@ -81,6 +81,14 @@ TensorUtils<TENSOR_TYPE>::squeeze1d(THCState *state, \
TENSOR_TYPE##_squeeze1d(state, dst, src, dimension); \
} \
\
void \
TensorUtils<TENSOR_TYPE>::unsqueeze1d(THCState *state, \
TENSOR_TYPE *dst, \
TENSOR_TYPE *src, \
int dimension) { \
TENSOR_TYPE##_unsqueeze1d(state, dst, src, dimension); \
} \
\
DATA_TYPE* \
TensorUtils<TENSOR_TYPE>::getData(THCState* state, \
TENSOR_TYPE* t) { \
@@ -133,6 +141,24 @@ TensorUtils<TENSOR_TYPE>::allContiguous(THCState* state, \
return true; \
} \
\
/* Due to the resize semantics of ops with `out=` keywords, if */ \
/* the output `tensor` has the same shape as the output of the */ \
/* reduction operation, then any noncontiguities in the output */ \
/* `tensor` should be preserved. This needs to be special cased b/c */ \
/* otherwise, when keepdim=False, the implementations of reduction */ \
/* ops resize `tensor` to the reduced size with keepdim=True, and */ \
/* then later squeeze `tensor` to the correct output size, breaking */ \
/* the contiguity guarantees of the resize semantics. */ \
void \
TensorUtils<TENSOR_TYPE>::preserveReduceDimSemantics( \
THCState *state, TENSOR_TYPE *tensor, \
int in_dims, int64_t dimension, int keepdim) {\
int out_dims = TensorUtils<TENSOR_TYPE>::getDims(state, tensor); \
if (out_dims > 0 && !keepdim && out_dims == in_dims - 1) { \
TensorUtils<TENSOR_TYPE>::unsqueeze1d(state, tensor, tensor, dimension);\
} \
} \
\
int \
TensorUtils<TENSOR_TYPE>::getDevice(THCState* state, \
TENSOR_TYPE* t) { \
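
The failure mode that comment describes is easy to see from Python. A sketch with hypothetical tensors, under the legacy TH resize semantics this patch targets (newer releases may warn about, or refuse, resizing a view):

import torch

base = torch.randn(5, 3)
col = base[:, 1]      # stride (3,): a view into base's storage
col.resize_(5, 1)     # naive keepdim resize: strides reset to contiguous (1, 1)
print(col.stride())   # (1, 1) -- subsequent writes no longer land in base's column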
5 changes: 5 additions & 0 deletions aten/src/THC/THCTensorTypeUtils.cuh
@@ -51,6 +51,11 @@ struct TensorUtils {
TENSOR_TYPE* src); \
static void squeeze1d(THCState *state, TENSOR_TYPE *dst, \
TENSOR_TYPE *src, int dimension); \
static void unsqueeze1d(THCState *state, TENSOR_TYPE *dst, \
TENSOR_TYPE *src, int dimension); \
static void preserveReduceDimSemantics( \
THCState *state, TENSOR_TYPE *tensor, \
int in_dims, int64_t dimension, int keepdim); \
static DATA_TYPE* getData(THCState* state, TENSOR_TYPE* t); \
static ptrdiff_t getNumElements(THCState* state, TENSOR_TYPE* t); \
static int64_t getSize(THCState* state, TENSOR_TYPE* t, int dim); \
25 changes: 18 additions & 7 deletions aten/src/THC/generic/THCTensorMathReduce.cu
@@ -77,6 +77,9 @@ THC_API void
THCTensor_(std)(THCState *state, THCTensor *self_, THCTensor *src, int dimension, int biased, int keepdim)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));

TensorUtils<THCTensor>::preserveReduceDimSemantics(
state, self_, THCTensor_(nDimension)(state, src), dimension, keepdim);
THLongStorage *dim = THCTensor_(newSizeOf)(state, src);
THLongStorage_set(dim, dimension, 1);
THCTensor_(resize)(state, self_, dim, NULL);
@@ -103,6 +106,9 @@ THC_API void
THCTensor_(var)(THCState *state, THCTensor *self_, THCTensor *src, int dimension, int biased, int keepdim)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src));

TensorUtils<THCTensor>::preserveReduceDimSemantics(
state, self_, THCTensor_(nDimension)(state, src), dimension, keepdim);
THLongStorage *dim = THCTensor_(newSizeOf)(state, src);
THLongStorage_set(dim, dimension, 1);
THCTensor_(resize)(state, self_, dim, NULL);
@@ -383,17 +389,22 @@ THCTensor_(median)(THCState *state,

THCTensor_(sort)(state, sorted, sorted_indices, self, dimension, 0);

THCTensor_(narrow)(state, values, sorted, dimension, k, 1);
THCudaLongTensor_narrow(state, indices, sorted_indices, dimension, k, 1);

THCTensor_(free)(state, sorted);
THCudaLongTensor_free(state, sorted_indices);
THCTensor *newValues = THCTensor_(newNarrow)(state, sorted, dimension, k, 1);
THCudaLongTensor *newIndices = THCudaLongTensor_newNarrow(state, sorted_indices, dimension, k, 1);

if (!keepdim) {
THCTensor_(squeeze1d)(state, values, values, dimension);
THCudaLongTensor_squeeze1d(state, indices, indices, dimension);
THCTensor_(squeeze1d)(state, newValues, newValues, dimension);
THCudaLongTensor_squeeze1d(state, newIndices, newIndices, dimension);
}

THCTensor_(resizeAs)(state, values, newValues);
THCudaLongTensor_resizeAs(state, indices, newIndices);
THCTensor_(copy)(state, values, newValues);
THCudaLongTensor_copy(state, indices, newIndices);

THCTensor_(free)(state, newValues);
THCudaLongTensor_free(state, newIndices);

THCudaCheck(cudaGetLastError());
}
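
The narrow-into-temporaries-then-copy rewrite above is what lets tuple outputs keep their layout. A minimal usage sketch mirroring the new test (hypothetical tensors; assumes a build that includes this patch):

import torch

x = torch.randn(5, 3)
values = torch.randn(5, 3)
indices = torch.zeros(5, 3).long()
# write the results into noncontiguous column views of preallocated tensors
torch.median(x, 1, keepdim=False, out=(values[:, 1], indices[:, 1]))
v, i = torch.median(x, 1, keepdim=False)
assert torch.equal(values[:, 1], v) and torch.equal(indices[:, 1], i)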

4 changes: 4 additions & 0 deletions aten/src/THC/generic/THCTensorMode.cu
@@ -180,6 +180,10 @@ THC_API void THCTensor_(mode)(THCState *state,

// Resize output value, index Tensors to appropriate sizes (i.e. the same as
// the input Tensor, except at dim=dimension, the size is 1)
TensorUtils<THCTensor>::preserveReduceDimSemantics(
state, values, ndim, dimension, keepdim);
TensorUtils<THCudaLongTensor>::preserveReduceDimSemantics(
state, indices, ndim, dimension, keepdim);
dim = THCTensor_(newSizeOf)(state, input);
THLongStorage_set(dim, dimension, 1);
THCTensor_(resize)(state, values, dim, NULL);
30 changes: 26 additions & 4 deletions test/test_torch.py
@@ -370,17 +370,20 @@ def _test_dim_reduction(self, cast):
"mean", "median", "mode", "norm", "prod",
"std", "sum", "var", "max", "min"]

def normfn_attr(t, dim, keepdim=False):
def normfn_attr(t, dim, keepdim=False, out=None):
attr = getattr(torch, "norm")
return attr(t, 2, dim, keepdim)
return attr(t, 2, dim, keepdim, out=out)

for fn_name in dim_red_fns:
fn_attr = getattr(torch, fn_name) if fn_name != "norm" else normfn_attr

def fn(x, dim, keepdim=False):
ans = fn_attr(x, dim, keepdim=keepdim)
def fn(x, dim, keepdim=False, out=None):
ans = fn_attr(x, dim, keepdim=keepdim, out=out)
return ans if not isinstance(ans, tuple) else ans[0]

def fn_tuple(x, dim, keepdim=False, out=None):
return fn_attr(x, dim, keepdim=keepdim, out=out)

def test_multidim(x, dim):
self.assertEqual(fn(x, dim).unsqueeze(dim), fn(x, dim, keepdim=True))
self.assertEqual(x.ndimension() - 1, fn(x, dim).ndimension())
@@ -405,6 +408,25 @@ def test_multidim(x, dim):
x = cast(torch.randn(dims))
test_multidim(x, singleton_dim)

# check reducing with output kwargs
if fn_name in ['median', 'mode', 'max', 'min']:
y = cast(torch.randn(5, 3))
values = cast(torch.randn(5, 3))
indices = cast(torch.zeros(5, 3).long() - 1)
fn_tuple(y, 1, keepdim=False, out=(values[:, 1], indices[:, 1]))
values_expected, indices_expected = fn_tuple(y, 1, keepdim=False)
self.assertEqual(values[:, 1], values_expected,
'{} values with out= kwarg'.format(fn_name))
self.assertEqual(indices[:, 1], indices_expected,
'{} indices with out= kwarg'.format(fn_name))
continue

x = cast(torch.randn(5, 3))
y = cast(torch.randn(5, 3))
fn(y, 1, keepdim=False, out=x[:, 1])
expected = fn(y, 1, keepdim=False)
self.assertEqual(x[:, 1], expected, '{} with out= kwarg'.format(fn_name))

def test_dim_reduction(self):
self._test_dim_reduction(self, lambda t: t)
