Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 41 additions & 35 deletions aten/src/ATen/native/Sorting.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include <ATen/native/Sorting.h>

#include <ATen/ATen.h>
#include <ATen/NumericUtils.h>
#include <ATen/Parallel.h>
Expand Down Expand Up @@ -31,41 +33,6 @@ namespace {
constexpr int64_t MAX_LEVELS = 300;
constexpr int64_t M_SMALL = 10; // Limit for small subfiles

// Applies `f` to every 1-d slice of `tensors` taken along dimension `dim`,
// parallelized over the product of all the other dimensions.
// All tensors are indexed with the sizes of tensors[0] outside `dim`
// (not checked here — NOTE(review): callers must guarantee matching shapes).
template <typename Fn>
void dim_apply(TensorList tensors, int64_t dim, Fn f) {
  AT_ASSERT(tensors.size() > 0);
  auto t = tensors[0];
  auto sizes = t.sizes();
  int64_t ndim = t.dim();
  // Number of slices = product of all sizes except the reduced dimension.
  int64_t itersize = 1;
  for (int64_t i = 0; i < ndim; i++) {
    if (i != dim) {
      itersize *= t.size(i);
    }
  }
  parallel_for(0, itersize, 1, [&](int64_t i_begin, int64_t i_end) {
    // Scratch vector reused across iterations to avoid reallocations.
    std::vector<Tensor> narrowed_tensors;
    narrowed_tensors.reserve(tensors.size());
    for (int64_t it = i_begin; it < i_end; it++) {
      narrowed_tensors.clear();
      for (auto ti : tensors) {
        // Decode the flat slice index `it` into per-dimension coordinates,
        // select()-ing down to a 1-d view along `dim`.
        int64_t i = it;
        Tensor nt = ti;
        for (size_t d = 0; d < ndim; d++) {
          if (d != dim) {
            // this could be avoided for slower-changing dimensions if done
            // better
            // (each select() drops one dimension, so the axis to select is 0
            // until we pass `dim`, then 1)
            nt = nt.select((d > dim ? 1 : 0), i % sizes[d]);
            i = i / sizes[d];
          }
        }
        narrowed_tensors.emplace_back(nt);
      }
      // `it` identifies the slice; narrowed_tensors holds one 1-d view per
      // input tensor, in the same order as `tensors`.
      f(it, narrowed_tensors);
    }
  });
}

template <typename scalar_t, typename Comp, typename Fn>
void quick_select_template(
TensorAccessor<scalar_t, 1> arr,
Expand Down Expand Up @@ -201,6 +168,43 @@ std::tuple<Tensor, Tensor> kthvalue(
return std::make_tuple(values, indices);
}

// Computes the k largest (or smallest) elements of `self` along `dim_`,
// writing results into the caller-provided `values`/`indices` tensors.
// Returns references to the two outputs.
std::tuple<Tensor&, Tensor&> topk_out_cpu(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t k,
    int64_t dim_,
    bool largest,
    bool sorted) {
  // Wrap negative dims; a 0-d input is treated as having one element.
  const int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
  const int64_t slice_size = self.dim() > 0 ? self.size(dim) : 1;
  TORCH_CHECK(
      k >= 0 && k <= slice_size,
      "selected index k out of range");

  _allocate_or_resize_output_with_indices(values, indices, self, dim_, k);

  // A 0-d, single-element input needs no reduction: its only value is the
  // top-1 and the only valid index is 0.
  const bool scalar_input = self.dim() == 0 && self.numel() == 1;
  if (scalar_input) {
    values.copy_(self);
    indices.zero_();
  } else {
    topk_stub(kCPU, values, indices, self, k, dim, largest, sorted);
  }
  return std::forward_as_tuple(values, indices);
}

// Functional variant of topk: allocates fresh output tensors and delegates
// to the registered out-variant.
std::tuple<Tensor, Tensor> topk(
    const Tensor& self,
    int64_t k,
    int64_t dim,
    bool largest,
    bool sorted) {
  // Start from empty tensors; the out-variant resizes them as needed.
  auto values = at::empty({0}, self.options());
  auto indices = at::empty({0}, self.options().dtype(kLong));
  at::topk_out(values, indices, self, k, dim, largest, sorted);
  return std::make_tuple(values, indices);
}

std::tuple<Tensor&, Tensor&> median_out(
Tensor& values,
Tensor& indices,
Expand Down Expand Up @@ -249,5 +253,7 @@ Tensor median_cpu(const Tensor& self) {
return result.view({});
}

DEFINE_DISPATCH(topk_stub);

} // namespace native
} // namespace at
12 changes: 12 additions & 0 deletions aten/src/ATen/native/Sorting.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/native/DispatchStub.h>

namespace at { namespace native {

// Signature of a per-device top-k kernel:
// (values_out, indices_out, self, k, dim, largest, sorted).
using topk_fn = void(*)(Tensor&, Tensor&, const Tensor&, int64_t, int64_t, bool, bool);

// Device-dispatched top-k; the CPU implementation is registered in
// cpu/SortingKernel.cpp and the stub is defined in Sorting.cpp.
DECLARE_DISPATCH(topk_fn, topk_stub);

}} // at::native
67 changes: 67 additions & 0 deletions aten/src/ATen/native/SortingUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,41 @@
namespace at {
namespace native {

// Applies `f` to every 1-d slice of `tensors` taken along dimension `dim`,
// parallelized over the product of all the other dimensions.
// All tensors are indexed with the sizes of tensors[0] outside `dim`
// (not checked here — NOTE(review): callers must guarantee matching shapes).
template <typename Fn>
void dim_apply(TensorList tensors, int64_t dim, Fn f) {
  AT_ASSERT(tensors.size() > 0);
  auto t = tensors[0];
  auto sizes = t.sizes();
  int64_t ndim = t.dim();
  // Number of slices = product of all sizes except the reduced dimension.
  int64_t itersize = 1;
  for (int64_t i = 0; i < ndim; i++) {
    if (i != dim) {
      itersize *= t.size(i);
    }
  }
  parallel_for(0, itersize, 1, [&](int64_t i_begin, int64_t i_end) {
    // Scratch vector reused across iterations to avoid reallocations.
    std::vector<Tensor> narrowed_tensors;
    narrowed_tensors.reserve(tensors.size());
    for (int64_t it = i_begin; it < i_end; it++) {
      narrowed_tensors.clear();
      // const& avoids a refcount bump per tensor per iteration.
      for (const auto& ti : tensors) {
        // Decode the flat slice index `it` into per-dimension coordinates,
        // select()-ing down to a 1-d view along `dim`.
        int64_t i = it;
        Tensor nt = ti;
        // int64_t (not size_t) so the d < ndim / d != dim / d > dim
        // comparisons do not mix signed and unsigned integers.
        for (int64_t d = 0; d < ndim; d++) {
          if (d != dim) {
            // this could be avoided for slower-changing dimensions if done
            // better
            // (each select() drops one dimension, so the axis to select is 0
            // until we pass `dim`, then 1)
            nt = nt.select((d > dim ? 1 : 0), i % sizes[d]);
            i = i / sizes[d];
          }
        }
        narrowed_tensors.emplace_back(nt);
      }
      // `it` identifies the slice; narrowed_tensors holds one 1-d view per
      // input tensor, in the same order as `tensors`.
      f(it, narrowed_tensors);
    }
  });
}

// ensure we get good values and indices for kthvalue, mode, median
// this will always be with the reducing dim as 1-d
static void _reduction_with_indices_allocate_or_resize_output(
Expand Down Expand Up @@ -44,5 +79,37 @@ static void _reduction_with_indices_allocate_or_resize_output(
}
}

// ensure we get good values and indices for topk
static void _allocate_or_resize_output_with_indices(
Tensor& values,
Tensor& indices,
const Tensor& self,
int64_t dim_,
int64_t k) {
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
auto result_sizes = self.sizes().vec();
if (result_sizes.size() > 0) {
result_sizes[dim] = k;
}
if (values.defined()) {
TORCH_CHECK(
self.type() == values.type(),
"output values must be of same type as input");
values.resize_(result_sizes);
} else {
values = at::empty(result_sizes, self.options());
}
if (indices.defined()) {
TORCH_CHECK(
indices.dtype() == kLong, "output indices must be of scalar type Long");
TORCH_CHECK(
indices.device() == self.device(),
"output indices must be on same device as input");
indices.resize_(result_sizes);
} else {
indices = at::empty(result_sizes, self.options().dtype(kLong));
}
}

} // namespace native
} // namespace at
90 changes: 90 additions & 0 deletions aten/src/ATen/native/cpu/SortingKernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/Sorting.h>
#include <ATen/native/SortingUtils.h>

namespace at { namespace native {

namespace {

// CPU kernel for topk: for each 1-d slice of `self` along `dim`, writes the
// k largest (or smallest) values and their source indices into
// `values`/`indices`.  If `sorted`, the k results are emitted in order.
static void topk_kernel(
    Tensor& values,
    Tensor& indices,
    const Tensor& self,
    int64_t k,
    int64_t dim,
    bool largest,
    bool sorted) {
  AT_DISPATCH_ALL_TYPES(self.scalar_type(), "topk_cpu", [&] {
    dim_apply(
        {self, values, indices},
        dim,
        [&](int64_t i, TensorList tl) {
          auto tmp_values = tl[0].accessor<scalar_t, 1>();
          auto mode_values = tl[1].accessor<scalar_t, 1>();
          auto mode_indices = tl[2].accessor<int64_t, 1>();

          auto n = tmp_values.size(0);
          // Heuristic: partial_sort when k << n, nth_element (+ optional
          // sort of the k-prefix) otherwise.
          auto use_partial_sort = k * 64 <= n;

          using elem_t = std::pair<scalar_t, int64_t>;
          std::vector<elem_t> queue(n);
          for (int64_t j = 0; j < n; j++) {
            queue[j].first = tmp_values[j];
            queue[j].second = j;
          }

          // we want NaN to be sorted as top for numpy compatibility
          // Single definition per ordering (instead of six inline copies) so
          // partial_sort / nth_element / sort cannot drift apart.
          auto cmp_largest = [](const elem_t& x, const elem_t& y) -> bool {
            return (_isnan<scalar_t>(x.first) && !_isnan<scalar_t>(y.first)) ||
                (x.first > y.first);
          };
          auto cmp_smallest = [](const elem_t& x, const elem_t& y) -> bool {
            return (!_isnan<scalar_t>(x.first) && _isnan<scalar_t>(y.first)) ||
                (x.first < y.first);
          };

          if (use_partial_sort) {
            // partial_sort leaves [0, k) fully sorted, so `sorted` is free.
            if (largest) {
              std::partial_sort(
                  queue.begin(), queue.begin() + k, queue.end(), cmp_largest);
            } else {
              std::partial_sort(
                  queue.begin(), queue.begin() + k, queue.end(), cmp_smallest);
            }
          } else {
            // k >= 1 here: use_partial_sort is always true when k == 0,
            // so queue.begin() + k - 1 is a valid iterator.
            // nth_element puts the k-th best element at position k - 1 and
            // partitions the rest; sorting [0, k - 1) then yields a fully
            // sorted k-prefix.
            if (largest) {
              std::nth_element(
                  queue.begin(), queue.begin() + k - 1, queue.end(), cmp_largest);
              if (sorted) {
                std::sort(queue.begin(), queue.begin() + k - 1, cmp_largest);
              }
            } else {
              std::nth_element(
                  queue.begin(), queue.begin() + k - 1, queue.end(), cmp_smallest);
              if (sorted) {
                std::sort(queue.begin(), queue.begin() + k - 1, cmp_smallest);
              }
            }
          }

          for (int64_t j = 0; j < k; j++) {
            mode_values[j] = queue[j].first;
            mode_indices[j] = queue[j].second;
          }
        });
  });
}

} // anonymous namespace

REGISTER_DISPATCH(topk_stub, &topk_kernel);

}} //at::native
5 changes: 1 addition & 4 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4081,14 +4081,11 @@

- func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) ->(Tensor(a!) values, Tensor(b!) indices)
dispatch:
CPU: legacy::cpu::_th_topk_out
CPU: topk_out_cpu
CUDA: legacy::cuda::_th_topk_out

- func: topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
variants: method, function
dispatch:
CPU: legacy::cpu::_th_topk
CUDA: legacy::cuda::_th_topk

- func: all(Tensor self) -> Tensor
variants: method, function
Expand Down