Skip to content

Commit a699cb9

Browse files
committed
Fixed topk for half, marked sort_index with half unsupported
1 parent 82ca3b3 commit a699cb9

2 files changed

Lines changed: 44 additions & 6 deletions

File tree

src/backend/opencl/sort_index.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ namespace opencl {
2525
template<typename T>
2626
void sort_index(Array<T> &okey, Array<uint> &oval, const Array<T> &in,
2727
const uint dim, bool isAscending) {
28+
29+
// TODO: fix half implementation of sort0bykey to support this
30+
if (std::is_same_v<T, half>) {
31+
OPENCL_NOT_SUPPORTED("sort_index with half");
32+
}
33+
2834
try {
2935
// okey contains values, oval contains indices
3036
okey = copyArray<T>(in);

src/backend/opencl/topk.cpp

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,17 @@
88
********************************************************/
99

1010
#include <Array.hpp>
11+
#include <common/cast.hpp>
1112
#include <common/half.hpp>
13+
#include <common/moddims.hpp>
1214
#include <err_opencl.hpp>
1315
#include <index.hpp>
1416
#include <sort.hpp>
1517
#include <sort_index.hpp>
1618
#include <types.hpp>
19+
#include <handle.hpp>
20+
#include <arith.hpp>
21+
#include <range.hpp>
1722

1823
#include <algorithm>
1924
#include <cmath>
@@ -157,12 +162,39 @@ void topk(Array<T>& vals, Array<unsigned>& idxs, const Array<T>& in,
157162
vals = values;
158163
idxs = indices;
159164
} else {
160-
auto values = createEmptyArray<T>(in.dims());
161-
auto indices = createEmptyArray<unsigned>(in.dims());
162-
sort_index(values, indices, in, dim, order & AF_TOPK_MIN);
163-
auto indVec = indexForTopK(k);
164-
vals = index<T>(values, indVec.data());
165-
idxs = index<unsigned>(indices, indVec.data());
165+
166+
if (!std::is_same_v<T, half>) {
167+
auto values = createEmptyArray<T>(in.dims());
168+
auto indices = createEmptyArray<unsigned>(in.dims());
169+
sort_index(values, indices, in, dim, order & AF_TOPK_MIN);
170+
auto indVec = indexForTopK(k);
171+
idxs = index<unsigned>(indices, indVec.data());
172+
vals = index<T>(values, indVec.data());
173+
} else {
174+
// Temporary implementation for topk due half not being supported in sort_index
175+
// TODO: Fix sort_index and remove this
176+
177+
auto values = createEmptyArray<float>(in.dims());
178+
auto indices = createEmptyArray<unsigned>(in.dims());
179+
sort_index(values, indices, common::cast<float>(in), dim, order & AF_TOPK_MIN);
180+
181+
auto indVec = indexForTopK(k);
182+
idxs = index<unsigned>(indices, indVec.data());
183+
184+
// Index values from original array by using the indices from the previous resuult
185+
auto len = in.elements() / in.dims()[dim];
186+
auto index_dims = dim4(k, len);
187+
auto new_indices = common::flat(arithOp<unsigned, af_add_t>(arithOp<unsigned, af_mul_t>(range<unsigned>(index_dims, 1), createValueArray<unsigned>(index_dims, in.dims()[dim]), index_dims), idxs, index_dims));
188+
auto indVecVals = indexForTopK(k);
189+
indVecVals[0].idx.arr = getHandle(new_indices);
190+
indVecVals[0].isSeq = false;
191+
indVecVals[0].isBatch = false;
192+
193+
vals = common::modDims(index<T>(common::flat(in), indVecVals.data()), idxs.dims());
194+
vals.eval();
195+
196+
releaseHandle<unsigned>(indVecVals[0].idx.arr);
197+
}
166198
}
167199
}
168200

0 commit comments

Comments
 (0)