28 changes: 23 additions & 5 deletions aten/src/TH/generic/THTensorMoreMath.cpp
@@ -2065,11 +2065,29 @@ void THTensor_(renorm)(THTensor *res, THTensor *src, scalar_t value, int dimensi

accreal THTensor_(dist)(THTensor *tensor, THTensor *src, scalar_t value)
{
-  scalar_t sum = 0;
-  TH_TENSOR_APPLY2(scalar_t, tensor, scalar_t, src,
-                   sum += TH_MATH_NAME(pow)(
-                     TH_MATH_NAME(fabs)(*tensor_data - *src_data), value););
-  return TH_MATH_NAME(pow)(sum, 1.0/value);
+  scalar_t sum;
+  if (value == INFINITY) {
+    sum = -1.0;
+    TH_TENSOR_APPLY2(scalar_t, tensor, scalar_t, src,
+                     sum = THMax(sum, TH_MATH_NAME(fabs)(*tensor_data - *src_data)););
+    return sum;
+  } else if (value == -INFINITY) {
+    sum = INFINITY;
+    TH_TENSOR_APPLY2(scalar_t, tensor, scalar_t, src,
+                     sum = THMin(sum, TH_MATH_NAME(fabs)(*tensor_data - *src_data)););
+    return sum;
+  } else if (value == 0.0) {
+    sum = 0.0;
+    TH_TENSOR_APPLY2(scalar_t, tensor, scalar_t, src,
+                     sum += (*tensor_data - *src_data != 0.0););
+    return sum;
+  } else {
+    sum = 0.0;
+    TH_TENSOR_APPLY2(scalar_t, tensor, scalar_t, src,
+                     sum += TH_MATH_NAME(pow)(
+                       TH_MATH_NAME(fabs)(*tensor_data - *src_data), value););
+    return TH_MATH_NAME(pow)(sum, 1.0/value);
+  }
}

accreal THTensor_(meanall)(THTensor *tensor)
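For context, the CPU change above makes THTensor_(dist) agree with the p-norm of the element-wise difference for the previously unsupported cases p = inf, -inf and 0. A minimal Python sketch of the intended user-visible behavior (illustrative values only; not part of this diff):

import torch

x = torch.tensor([1.0, -2.0, 3.0])
y = torch.tensor([1.0,  0.5, 0.0])
d = (x - y).abs()          # element-wise |x - y| = [0.0, 2.5, 3.0]

# p = inf  -> largest absolute difference
# p = -inf -> smallest absolute difference
# p = 0    -> number of non-zero differences
assert float(torch.dist(x, y, float('inf'))) == float(d.max())    # 3.0
assert float(torch.dist(x, y, float('-inf'))) == float(d.min())   # 0.0
assert float(torch.dist(x, y, 0)) == 2.0
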
18 changes: 15 additions & 3 deletions aten/src/THC/THCTensorMathReduce.cuh
@@ -292,9 +292,21 @@ struct ThrustTensorDistOp {
__host__ __device__ AccT operator()(T _x, T _y) const {
const AccT x = scalar_cast<AccT>(_x);
const AccT y = scalar_cast<AccT>(_y);
-    return THCNumerics<AccT>::pow(
-      THCNumerics<AccT>::abs(THCNumerics<AccT>::sub(x, y)),
-      exponent);
+    if (THCNumerics<AccT>::eq(exponent, scalar_cast<AccT, float>(0))) {
+      const AccT zero = scalar_cast<AccT>(0);
+      if (THCNumerics<AccT>::eq(THCNumerics<AccT>::sub(x, y), zero)) return zero;
+      return scalar_cast<AccT>(1);
+    }
+    if (THCNumerics<AccT>::eq(exponent, scalar_cast<AccT, float>(1))) {
+      return THCNumerics<AccT>::abs(THCNumerics<AccT>::sub(x, y));
+    } else if (THCNumerics<AccT>::eq(exponent, scalar_cast<AccT, float>(2))) {
+      return THCNumerics<AccT>::pow(
+        THCNumerics<AccT>::sub(x, y), exponent);
+    } else {
+      return THCNumerics<AccT>::pow(
+        THCNumerics<AccT>::abs(THCNumerics<AccT>::sub(x, y)),
+        exponent);
+    }
}

const AccT exponent;
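The functor above computes each element's contribution before the Thrust reduction runs; the p = 1 and p = 2 branches avoid the general pow/abs path. A rough Python equivalent of the per-element op (a sketch only; dist_op is a made-up name, not an API in this diff):

def dist_op(x, y, p):
    # Per-element contribution, mirroring the branches of ThrustTensorDistOp.
    d = x - y
    if p == 0:
        return 0.0 if d == 0 else 1.0   # indicator: counts non-zero differences
    if p == 1:
        return abs(d)                   # no pow call needed
    if p == 2:
        return d * d                    # even power, so abs() is unnecessary
    return abs(d) ** p                  # general case
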
41 changes: 35 additions & 6 deletions aten/src/THC/generic/THCTensorMathReduce.cu
@@ -261,18 +261,47 @@ accreal THCTensor_(dist)(THCState *state, THCTensor *self,
thrust::device_ptr<scalar_t> src_data(THCTensor_(data)(state, src));

THCThrustAllocator thrustAlloc(state);
-  accreal result = thrust::inner_product(
+  accreal result;
+
+  if (THCNumerics<accreal>::eq(value, scalar_cast<accreal>(INFINITY))) {
+    result = thrust::inner_product(
+#if CUDA_VERSION >= 7000
+      thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
+#endif
+      self_data, self_data+size, src_data, scalar_cast<accreal>(0),
+      ReduceMax<accreal>(),
+      ThrustTensorDistOp<scalar_t, accreal>(scalar_cast<scalar_t>(1)));
+  } else if (THCNumerics<accreal>::eq(value, scalar_cast<accreal>(-INFINITY))) {
+    result = thrust::inner_product(
+#if CUDA_VERSION >= 7000
+      thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
+#endif
+      self_data, self_data+size, src_data, scalar_cast<accreal>(INFINITY),
+      ReduceMin<accreal>(),
+      ThrustTensorDistOp<scalar_t, accreal>(scalar_cast<scalar_t>(1)));
+  } else if (THCNumerics<accreal>::eq(value, scalar_cast<accreal>(0))) {
+    result = thrust::inner_product(
+#if CUDA_VERSION >= 7000
+      thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
+#endif
+      self_data, self_data+size, src_data, scalar_cast<accreal>(0),
+      thrust::plus<accreal>(),
+      ThrustTensorDistOp<scalar_t, accreal>(scalar_cast<scalar_t>(0)));
+  } else {
+    result = thrust::inner_product(
#if CUDA_VERSION >= 7000
-    thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
+      thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
-    self_data, self_data+size, src_data, scalar_cast<accreal>(0),
-    thrust::plus<accreal>(),
-    ThrustTensorDistOp<scalar_t, accreal>(value));
+      self_data, self_data+size, src_data, scalar_cast<accreal>(0),
+      thrust::plus<accreal>(),
+      ThrustTensorDistOp<scalar_t, accreal>(value));
+
+    result = THCNumerics<accreal>::pow(result, THCNumerics<accreal>::cinv(value));
+  }
THCTensor_(free)(state, src);
THCTensor_(free)(state, self);

-  return THCNumerics<accreal>::pow(result, THCNumerics<accreal>::cinv(value));
+  return result;
}

#endif
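Each branch above is a single thrust::inner_product call: an element-wise transform (ThrustTensorDistOp) followed by a reduction, with the initial value and reduction operator chosen per p, and the final 1/p power applied only on the general path. A rough Python analogue of that structure for 1-D inputs (a sketch; dist_reduce is illustrative, not the actual Thrust code):

import math
from functools import reduce

def dist_reduce(xs, ys, p):
    # Transform step: per-element |x - y| (compare the dist_op sketch above).
    diffs = [abs(a - b) for a, b in zip(xs, ys)]
    if p == math.inf:
        return reduce(max, diffs, 0.0)                   # init 0, ReduceMax
    if p == -math.inf:
        return reduce(min, diffs, math.inf)              # init inf, ReduceMin
    if p == 0:
        return float(sum(1 for d in diffs if d != 0))    # init 0, plus
    total = sum(d ** p for d in diffs)                   # init 0, plus
    return total ** (1.0 / p)                            # 1/p power only here
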
12 changes: 11 additions & 1 deletion test/test_cuda.py
@@ -17,7 +17,7 @@
from test_torch import _TestTorchMixin

from common_utils import TestCase, get_gpu_type, to_gpu, freeze_rng_state, run_tests, \
-    PY3, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, skipIfRocm, TEST_WITH_ROCM, load_tests
+    PY3, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, skipIfRocm, TEST_NUMPY, TEST_WITH_ROCM, load_tests

# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
@@ -1938,6 +1938,16 @@ def test_diagonal(self):
def test_diagflat(self):
_TestTorchMixin._test_diagflat(self, dtype=torch.float32, device='cuda')

+    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
+    @skipIfRocm
+    def test_norm(self):
+        _TestTorchMixin._test_norm(self, device='cuda')
+
+    @skipIfRocm
+    def test_dist(self):
+        _TestTorchMixin._test_dist(self, device='cuda')
+
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
def test_trtrs(self):
_TestTorchMixin._test_trtrs(self, lambda t: t.cuda())
23 changes: 17 additions & 6 deletions test/test_torch.py
@@ -871,12 +871,23 @@ def _test_norm(self, device):
def test_norm(self):
self._test_norm(self, device='cpu')

-    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
-    @unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
-    @skipIfNoLapack
-    @skipIfRocm
-    def test_norm_cuda(self):
-        self._test_norm(self, device='cuda')
+    @staticmethod
+    def _test_dist(self, device):
+        def run_test(x, y):
+            for p in [0, 1, 2, 3, 4, inf, -inf]:
+                dist_xy = torch.dist(x, y, p)
+                dist_xy_norm = torch.norm(x - y, p)
+                self.assertEqual(dist_xy, dist_xy_norm)
+
+        run_test(torch.randn(5, device=device), torch.randn(5, device=device))
+
+        x = torch.zeros(3, device=device)
+        y = torch.zeros(3, device=device)
+        y[1] = 1.
+        run_test(x, y)
+
+    def test_dist(self):
+        self._test_dist(self, device='cpu')

def test_dim_reduction_uint8_overflow(self):
example = [[-1, 2, 1], [5, 3, 6]]