36 changes: 30 additions & 6 deletions aten/src/ATen/native/Distance.cpp
@@ -1,7 +1,7 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/NativeFunctions.h>

#include <ATen/ExpandUtils.h>
#include <ATen/native/Distance.h>

namespace at { namespace native {
@@ -25,11 +25,11 @@ Tensor pdist(const Tensor& self, const double p) {
}

Tensor cdist(const Tensor& x1, const Tensor& x2, const double p) {
TORCH_CHECK(x1.dim() == 2, "cdist only supports 2D tensors, X1 got: ", x1.dim(), "D");
TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D");
TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X1 got: ", x1.scalar_type());
auto device1 = x1.type().device_type();
TORCH_CHECK(device1 == kCPU || device1 == kCUDA, "cdist only supports CPU and CUDA devices, X1 got: ", device1);
TORCH_CHECK(x2.dim() == 2, "cdist only supports 2D tensors, X2 got: ", x2.dim(), "D");
TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D");
TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X2 got: ", x2.scalar_type());
auto device2 = x2.type().device_type();
TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "cdist only supports CPU and CUDA devices, X2 got: ", device2);
@@ -42,12 +42,34 @@ Tensor cdist(const Tensor& x1, const Tensor& x2, const double p) {

int64_t r1 = x1.size(-2);
int64_t r2 = x2.size(-2);
Tensor result = at::empty({r1, r2}, x1.options());
auto dim1 = x1.dim();
auto dim2 = x2.dim();

// For the batched calculation we collapse all dimensions except the last two into one dimension whose size equals their product.
// The last two dimensions stay the same.
IntArrayRef batch_tensor1(x1.sizes().data(), dim1 - 2);
IntArrayRef batch_tensor2(x2.sizes().data(), dim2 - 2);
std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);
std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1});
std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2});

int expand_batch_product = std::accumulate(expand_batch_portion.begin(), expand_batch_portion.end(), 1, std::multiplies<int64_t>());
std::vector<int64_t> tensor1_view{expand_batch_product, r1, c1};
std::vector<int64_t> tensor2_view{expand_batch_product, r2, c2};

Tensor tensor1_expanded = x1.expand(tensor1_expand_size).contiguous().view(tensor1_view);
Tensor tensor2_expanded = x2.expand(tensor2_expand_size).contiguous().view(tensor2_view);

std::vector<int64_t> output_shape(expand_batch_portion);
Contributor: Why do you sometimes use foo.insert(foo.end(), {bar, quux}); and other times call foo.push_back twice? This should be consistent.

Contributor Author: Made it consistent.

output_shape.insert(output_shape.end(), {r1, r2});
Tensor result = at::empty(output_shape, x1.options());
if (r1 > 0 && r2 > 0) {
if (c1 == 0) {
result.fill_(0);
} else {
cdist_stub(device1, result, x1.contiguous(), x2.contiguous(), p);
cdist_stub(device1, result, tensor1_expanded, tensor2_expanded, p);
}
}
return result;
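
The forward hunk above broadcasts the leading (batch) dimensions of x1 and x2 and flattens them into a single dimension, so the kernels only ever see contiguous 3-D inputs of shape {batch, r, c}. The sketch below is a minimal, self-contained illustration of that shape computation; the helper names are hypothetical and not part of the patch, and broadcast_batch simply mirrors the broadcasting rule that infer_size applies to the batch dimensions.

```cpp
#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

// Hypothetical helper: broadcast two batch-shape vectors the way infer_size does.
// Align from the right; each pair of sizes must be equal or one of them must be 1.
std::vector<int64_t> broadcast_batch(std::vector<int64_t> a, std::vector<int64_t> b) {
  if (a.size() < b.size()) std::swap(a, b);
  std::vector<int64_t> out(a);
  const size_t offset = a.size() - b.size();
  for (size_t k = 0; k < b.size(); ++k) {
    int64_t& da = out[offset + k];
    const int64_t db = b[k];
    if (da != db && da != 1 && db != 1)
      throw std::runtime_error("batch shapes are not broadcastable");
    da = std::max(da, db);
  }
  return out;
}

// Hypothetical helper: flatten the broadcast batch dims into one leading dimension,
// e.g. batch {2, 3} with r rows and c columns becomes {6, r, c}.
std::vector<int64_t> flatten_batch(const std::vector<int64_t>& batch, int64_t r, int64_t c) {
  const int64_t prod = std::accumulate(batch.begin(), batch.end(),
                                       int64_t{1}, std::multiplies<int64_t>());
  return {prod, r, c};
}
```

cdist then expands each input to the broadcast shape, makes it contiguous, and views it as {expand_batch_product, r, c}, which is exactly what tensor1_view and tensor2_view describe above.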
@@ -63,7 +85,9 @@ Tensor _cdist_backward(const Tensor& grad, const Tensor& x1, const Tensor& x2, c
TORCH_CHECK(device1 == kCPU || device1 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X1 got: ", device1);
auto device2 = x2.type().device_type();
TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X2 got: ", device2);
Tensor grad_x1 = at::empty({n, m}, x1.options());
IntArrayRef batch_tensor1(x1.sizes().data(), std::max<int64_t>(x1.dim() - 2, 0));
int batch_product = std::accumulate(batch_tensor1.begin(), batch_tensor1.end(), 1, std::multiplies<int64_t>());
Tensor grad_x1 = at::empty_like(x1, x1.options()).view({batch_product, n, m});
cdist_backward_stub(device1, grad_x1, grad, x1, x2, p, cdist);
return grad_x1;
}
68 changes: 42 additions & 26 deletions aten/src/ATen/native/cpu/DistanceOpsKernel.cpp
@@ -197,25 +197,29 @@ struct Dist {
static void run_parallel_cdist(Tensor& result, const Tensor& t1, const Tensor& t2, const scalar_t p) {
const scalar_t * const t1_start = t1.data<scalar_t>();
const scalar_t * const t2_start = t2.data<scalar_t>();
int64_t d = t1.size(0);
int64_t r1 = t1.size(-2);
int64_t r2 = t2.size(-2);
int64_t m = t1.size(-1);

scalar_t * const res_start = result.data<scalar_t>();
int64_t total = r1 * r2;
int64_t combs = r1 * r2;
int64_t size1 = r1 * m;
int64_t size2 = r2 * m;

parallel_for(0, total, internal::GRAIN_SIZE / (16 * m), [=](int64_t start, int64_t end) {
parallel_for(0, combs * d, internal::GRAIN_SIZE / (16 * m), [=](int64_t start, int64_t end) {
scalar_t * res = res_start + start;
const scalar_t * const res_end = res_start + end;

int64_t i = start / r2;
int64_t j = start % r2;
int64_t l = start / combs;
int64_t k = start % combs;
int64_t i = k / r2;
int64_t j = k % r2;
i = i * m;
j = j * m;
int64_t size = r2 * m;

while (res != res_end) {
const scalar_t * self_i = t1_start + i;
const scalar_t * self_j = t2_start + j;
const scalar_t * self_i = t1_start + size1 * l + i;
const scalar_t * self_j = t2_start + size2 * l + j;

scalar_t agg = 0;
for (int x = 0; x < m; x++) {
Expand All @@ -227,9 +231,13 @@ struct Dist {

res += 1;
j += m;
if (j == size) {
if (j == size2) {
j = 0;
i += m;
if (i == size1) {
i = 0;
l += 1;
}
}
}
});
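
The parallel loop above walks a flat range of d * r1 * r2 result entries; each flat index encodes a batch element l and a pair of rows (i, j) from the two inputs. A standalone sketch of that decomposition (the function name is illustrative, not part of the patch):

```cpp
#include <cstdint>
#include <tuple>

// Split a flat index over d * r1 * r2 distance entries into
// (batch element, row of x1, row of x2), matching the loop above.
std::tuple<int64_t, int64_t, int64_t> decompose(int64_t flat, int64_t r1, int64_t r2) {
  const int64_t combs = r1 * r2;   // distances per batch element
  const int64_t l = flat / combs;  // which batch element
  const int64_t k = flat % combs;  // offset inside that batch element
  return std::make_tuple(l, k / r2, k % r2);  // (l, i, j)
}
```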
@@ -343,7 +351,10 @@ struct Dist {
const int64_t r1 = t1.size(-2);
const int64_t r2 = t2.size(-2);
const int64_t m = t1.size(-1);
const int64_t gs = grad.stride(1);
const int64_t d = result.size(0);
const int64_t l1_size = r1 * m;
const int64_t l2_size = r2 * m;
const int64_t gs = grad.stride(-1);

const scalar_t * const grad_start = grad.data<scalar_t>();
const scalar_t * const dist_start = dist.data<scalar_t>();
Expand All @@ -359,31 +370,36 @@ struct Dist {
scalar_t * res_l = res_start + l * Vec::size();

for (const scalar_t * const res_end = res_start + end * Vec::size(); res_l != res_end; i += Vec::size(), j += Vec::size(), res_l += Vec::size()) {
backward_down_column_cdist<F>(i, j, res_l, grad_start, dist_start, pvec, r1, r2, m, gs);
backward_down_column_cdist<F>(i, j, res_l, grad_start, dist_start, pvec, r1, r2, m, d, gs, l1_size, l2_size);
}
});
const int64_t remainder = m % Vec::size();
if (remainder) {
backward_down_column_cdist<F>(t1_start + (m - remainder), t2_start + (m - remainder), res_start + (m - remainder), grad_start, dist_start, Vec(p), r1, r2, m, gs, remainder);
backward_down_column_cdist<F>(t1_start + (m - remainder), t2_start + (m - remainder), res_start + (m - remainder), grad_start, dist_start, Vec(p), r1, r2, m, d, gs, l1_size, l2_size, remainder);
}
}

template <typename F>
inline static void backward_down_column_cdist(const scalar_t * t1, const scalar_t * t2, scalar_t * res, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t r1, int64_t r2, int64_t m, int64_t gs, int64_t count = Vec::size()) {
const scalar_t * const t1_end = t1 + m * r1;
const scalar_t * const t2_end = t2 + m * r2;

for (; t1 != t1_end; t1 += m, res += m) {
const Vec vec_t1 = Vec::loadu(t1, count);
Vec res_vec = Vec::loadu(res, count);

for (const scalar_t * t2_curr = t2; t2_curr != t2_end; t2_curr += m, grad_k += gs, dist_k += 1) {
const Vec vec_t2 = Vec::loadu(t2_curr, count);
Vec res = F::backward(vec_t1 - vec_t2, *grad_k, *dist_k, pvec);
res_vec = res_vec + res;
}
inline static void backward_down_column_cdist(const scalar_t * t1, const scalar_t * t2, scalar_t * res, const scalar_t * grad_k, const scalar_t * dist_k, const Vec& pvec, int64_t r1, int64_t r2, int64_t m, int64_t d, int64_t gs, int64_t l1_size, int64_t l2_size, int64_t count = Vec::size()) {
const scalar_t * t1_end = t1 + l1_size;
const scalar_t * t2_end = t2 + l2_size;

for (int64_t l = 0; l < d; l++) {
for (; t1 != t1_end; t1 += m, res += m) {
const Vec vec_t1 = Vec::loadu(t1, count);
Vec res_vec = Vec::loadu(res, count);

for (const scalar_t * t2_curr = t2; t2_curr != t2_end; t2_curr += m, grad_k += gs, dist_k += 1) {
const Vec vec_t2 = Vec::loadu(t2_curr, count);
Vec res = F::backward(vec_t1 - vec_t2, *grad_k, *dist_k, pvec);
res_vec = res_vec + res;
}

res_vec.store(res, count);
res_vec.store(res, count);
}
t1_end += l1_size;
t2_end += l2_size;
t2 += l2_size;
}
}
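
The pointer arithmetic in this hunk relies on the expanded inputs being contiguous with shape {d, r, m}, so advancing t1_end, t2_end, and t2 by l1_size or l2_size moves exactly one batch element forward. As a tiny illustration (a hypothetical helper, not part of the patch), element (l, i, x) of such a tensor sits at the offset computed below:

```cpp
#include <cstdint>

// Flat offset of element (l, i, x) in a contiguous {d, r, m} tensor:
// skip l full batch elements, then i full rows, then x columns.
inline int64_t offset_in_batch(int64_t l, int64_t i, int64_t x, int64_t r, int64_t m) {
  return l * r * m + i * m + x;
}
```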

73 changes: 43 additions & 30 deletions aten/src/ATen/native/cuda/DistanceKernel.cu
@@ -134,27 +134,30 @@ __global__ static void pdist_kernel_cuda_impl(scalar_t * result, const scalar_t

template <typename scalar_t, typename F>
__global__ static void cdist_backward_kernel_cuda_impl(scalar_t * buffer, const scalar_t * grad, const scalar_t * x1, const scalar_t * x2, const scalar_t * dist, int64_t gs,
const scalar_t p, const int64_t r1, const int64_t r2, const int64_t m, const int64_t count) {
const int k = blockIdx.y * blockDim.y + threadIdx.y;
const scalar_t p, const int64_t r1, const int64_t r2, const int64_t m, const int64_t count, const int64_t r_size, const int64_t l1_size, const int64_t l2_size) {
const int y = blockIdx.y * blockDim.y + threadIdx.y;
const int l = y / r_size;
const int k = y % r_size;
const int init = blockIdx.x * blockDim.x + threadIdx.x;
const int stride = blockDim.x * gridDim.x;
const int l_size = r_size * m;

if (k >= count) {
if (y >= count) {
return;
}

int64_t i = k / r2;
int64_t j = k % r2;

const scalar_t grad_k = grad[k * gs];
const scalar_t dist_k = dist[k];
const scalar_t grad_k = grad[y * gs];
const scalar_t dist_k = dist[y];

const scalar_t * const start = x1 + i * m;
const scalar_t * const start = x1 + l * l1_size + i * m;
const scalar_t * const end = start + m;
const scalar_t * self_i = start + init;
const scalar_t * self_j = x2 + j * m + init;
const scalar_t * self_j = x2 + l * l2_size + j * m + init;

scalar_t * buff_i = buffer + (r1 * j + i) * m + init;
scalar_t * buff_i = buffer + l * l_size + (r1 * j + i) * m + init;

for (; self_i < end; self_i += stride, self_j += stride, buff_i += stride) {
const scalar_t res = F::backward(*self_i - *self_j, grad_k, dist_k, p);
@@ -196,45 +199,51 @@ __global__ static void pdist_backward_kernel_cuda_impl(scalar_t * buffer, const
}

template <typename scalar_t, typename F>
__global__ static void cdist_kernel_cuda_impl(scalar_t * result, const scalar_t * x1, const scalar_t * x2, const scalar_t p, const int64_t r1, const int64_t r2, const int64_t m) {
const int k = blockIdx.x;
__global__ static void cdist_kernel_cuda_impl(scalar_t * result, const scalar_t * x1, const scalar_t * x2,
const scalar_t p, const int64_t r1, const int64_t r2, const int64_t m, const int64_t r_size, const int64_t l1_size, const int64_t l2_size) {
const int64_t l = blockIdx.x / r_size;
const int64_t k = blockIdx.x % r_size;
const int64_t i = k / r2;
const int64_t j = k % r2;
const int stride = blockDim.x;

const scalar_t * const start = x1 + i * m;
const scalar_t * const start = x1 + l * l1_size + i * m;
const scalar_t * const end = start + m;
const scalar_t * a = start + threadIdx.x;
const scalar_t * b = x2 + j * m + threadIdx.x;
const scalar_t * b = x2 + l * l2_size + j * m + threadIdx.x;

scalar_t agg = 0.0;
for (; a < end; a += stride, b += stride) {
F::inc(agg, std::abs(*a - *b), p);
}
agg = reduce_agg<scalar_t, F>(agg);
if (threadIdx.x == 0) {
result[k] = F::finish(agg, p);
result[blockIdx.x] = F::finish(agg, p);
}
}

void cdist_kernel_impl(Tensor& result, const Tensor& x1, const Tensor& x2, double p) {
int64_t r1 = x1.size(-2);
int64_t r2 = x2.size(-2);
int64_t m = x1.size(-1);
const dim3 grid(r1*r2);
const int64_t r1 = x1.size(-2);
const int64_t r2 = x2.size(-2);
const int64_t m = x1.size(-1);
const int64_t d = x1.size(0);
const int64_t r_size = r1 * r2;
const int64_t l1_size = r1 * m;
const int64_t l2_size = r2 * m;
const dim3 grid(result.numel());
const dim3 block(std::min((int64_t)forward_threads, ((m - 1) / WARP_SIZE + 1) * WARP_SIZE));

AT_DISPATCH_FLOATING_TYPES(x1.scalar_type(), "cdist_cuda", [&] {
if (p == 0.0) {
cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::zero><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::zero><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m, r_size, l1_size, l2_size);
} else if (p == 1.0) {
cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m, r_size, l1_size, l2_size);
} else if (p == 2.0) {
cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m, r_size, l1_size, l2_size);
} else if (std::isinf(p)) {
cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m, r_size, l1_size, l2_size);
} else {
cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m);
cdist_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(result.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), p, r1, r2, m, r_size, l1_size, l2_size);
}
});
AT_CUDA_CHECK(cudaGetLastError());
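
The launcher above assigns one CUDA block per output distance (d * r1 * r2 blocks in total) and lets each block's threads stride across the m coordinates before the block-level reduction in reduce_agg. A host-side sketch of that launch geometry; the struct and function names are hypothetical, and the default values for warp_size and forward_threads are assumptions for illustration rather than values taken from this file:

```cpp
#include <algorithm>
#include <cstdint>

struct LaunchGeometry {
  int64_t grid;   // one block per output distance: d * r1 * r2
  int64_t block;  // threads per block, rounded up to a whole warp
};

LaunchGeometry cdist_launch_geometry(int64_t d, int64_t r1, int64_t r2, int64_t m,
                                     int64_t warp_size = 32,        // assumed warp size
                                     int64_t forward_threads = 256  // assumed thread cap
                                     ) {
  const int64_t block = std::min(forward_threads, ((m - 1) / warp_size + 1) * warp_size);
  return {d * r1 * r2, block};
}
```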
@@ -316,33 +325,37 @@ void cdist_backward_kernel_impl(Tensor& result, const Tensor& grad, const Tensor
const int64_t r1 = x1.size(-2);
const int64_t r2 = x2.size(-2);
const int64_t m = x1.size(-1);
const int64_t d = x1.size(0);
const int block_x = 64;
const int block_y = 16;
const int grid_x = (m + block_x * 8 - 1) / (block_x * 8);
const int grid_y = (dist.numel() + block_y - 1) / block_y;
const int grid_y = ((r1 * r2 * d) + block_y - 1) / block_y;

const dim3 grid(grid_x, grid_y);
const dim3 block(block_x, block_y);

const int64_t count = dist.numel();
const int64_t r_size = r1 * r2;
const int64_t l1_size = r1 * m;
const int64_t l2_size = r2 * m;

Tensor buffer = at::empty({r2, r1, m}, result.options());
Tensor buffer = at::empty({d, r2, r1, m}, result.options());
AT_DISPATCH_FLOATING_TYPES(result.scalar_type(), "cdist_cuda_backward", [&] {
if (p == 1.0) {
cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(1), p, r1, r2, m, count);
cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::one><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(-1), p, r1, r2, m, count, r_size, l1_size, l2_size);
} else if (p < 2.0) {
cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::lt_two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(1), p, r1, r2, m, count);
cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::lt_two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(-1), p, r1, r2, m, count, r_size, l1_size, l2_size);
} else if (p == 2.0) {
cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(1), p, r1, r2, m, count);
cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::two><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(-1), p, r1, r2, m, count, r_size, l1_size, l2_size);
} else if (std::isinf(p)) {
cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(1), p, r1, r2, m, count);
cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::inf><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(-1), p, r1, r2, m, count, r_size, l1_size, l2_size);
} else {
cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(1), p, r1, r2, m, count);
cdist_backward_kernel_cuda_impl<scalar_t, dists<scalar_t>::p><<<grid, block>>>(buffer.data<scalar_t>(), grad.data<scalar_t>(), x1.data<scalar_t>(), x2.data<scalar_t>(), dist.data<scalar_t>(), grad.stride(-1), p, r1, r2, m, count, r_size, l1_size, l2_size);
}
});
AT_CUDA_CHECK(cudaGetLastError());

at::sum_out(result, buffer, 0);
at::sum_out(result, buffer, 1);
}
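
The backward kernels accumulate each pairwise contribution into a buffer of shape {d, r2, r1, m}, and the final at::sum_out over dim 1 collapses the r2 axis to produce the gradient with respect to x1. A minimal sketch of that reduction on plain arrays (hypothetical helper; shapes assumed contiguous, element (l, j, i, x) sits at the offset used in the kernel above):

```cpp
#include <cstdint>
#include <vector>

// Reduce a contiguous {d, r2, r1, m} buffer over dim 1 (the r2 axis),
// producing a {d, r1, m} gradient, as at::sum_out(result, buffer, 1) does.
std::vector<double> reduce_buffer(const std::vector<double>& buffer,
                                  int64_t d, int64_t r1, int64_t r2, int64_t m) {
  std::vector<double> grad(d * r1 * m, 0.0);
  for (int64_t l = 0; l < d; ++l)
    for (int64_t j = 0; j < r2; ++j)
      for (int64_t i = 0; i < r1; ++i)
        for (int64_t x = 0; x < m; ++x)
          grad[(l * r1 + i) * m + x] +=
              buffer[((l * r2 + j) * r1 + i) * m + x];
  return grad;
}
```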

