|
8 | 8 | ********************************************************/ |
9 | 9 |
|
10 | 10 | #include <Array.hpp> |
| 11 | +#include <common/cast.hpp> |
11 | 12 | #include <common/half.hpp> |
| 13 | +#include <common/moddims.hpp> |
12 | 14 | #include <err_opencl.hpp> |
13 | 15 | #include <index.hpp> |
14 | 16 | #include <sort.hpp> |
15 | 17 | #include <sort_index.hpp> |
16 | 18 | #include <types.hpp> |
| 19 | +#include <handle.hpp> |
| 20 | +#include <arith.hpp> |
| 21 | +#include <range.hpp> |
17 | 22 |
|
18 | 23 | #include <algorithm> |
19 | 24 | #include <cmath> |
@@ -157,12 +162,39 @@ void topk(Array<T>& vals, Array<unsigned>& idxs, const Array<T>& in, |
157 | 162 | vals = values; |
158 | 163 | idxs = indices; |
159 | 164 | } else { |
160 | | - auto values = createEmptyArray<T>(in.dims()); |
161 | | - auto indices = createEmptyArray<unsigned>(in.dims()); |
162 | | - sort_index(values, indices, in, dim, order & AF_TOPK_MIN); |
163 | | - auto indVec = indexForTopK(k); |
164 | | - vals = index<T>(values, indVec.data()); |
165 | | - idxs = index<unsigned>(indices, indVec.data()); |
| 165 | + |
| 166 | + if (!std::is_same_v<T, half>) { |
| 167 | + auto values = createEmptyArray<T>(in.dims()); |
| 168 | + auto indices = createEmptyArray<unsigned>(in.dims()); |
| 169 | + sort_index(values, indices, in, dim, order & AF_TOPK_MIN); |
| 170 | + auto indVec = indexForTopK(k); |
| 171 | + idxs = index<unsigned>(indices, indVec.data()); |
| 172 | + vals = index<T>(values, indVec.data()); |
| 173 | + } else { |
| 174 | + // Temporary implementation for topk due to half not being supported in sort_index |
| 175 | + // TODO: Fix sort_index and remove this |
| 176 | + |
| 177 | + auto values = createEmptyArray<float>(in.dims()); |
| 178 | + auto indices = createEmptyArray<unsigned>(in.dims()); |
| 179 | + sort_index(values, indices, common::cast<float>(in), dim, order & AF_TOPK_MIN); |
| 180 | + |
| 181 | + auto indVec = indexForTopK(k); |
| 182 | + idxs = index<unsigned>(indices, indVec.data()); |
| 183 | + |
| 184 | + // Index values from original array by using the indices from the previous result |
| 185 | + auto len = in.elements() / in.dims()[dim]; |
| 186 | + auto index_dims = dim4(k, len); |
| 187 | + auto new_indices = common::flat(arithOp<unsigned, af_add_t>(arithOp<unsigned, af_mul_t>(range<unsigned>(index_dims, 1), createValueArray<unsigned>(index_dims, in.dims()[dim]), index_dims), idxs, index_dims)); |
| 188 | + auto indVecVals = indexForTopK(k); |
| 189 | + indVecVals[0].idx.arr = getHandle(new_indices); |
| 190 | + indVecVals[0].isSeq = false; |
| 191 | + indVecVals[0].isBatch = false; |
| 192 | + |
| 193 | + vals = common::modDims(index<T>(common::flat(in), indVecVals.data()), idxs.dims()); |
| 194 | + vals.eval(); |
| 195 | + |
| 196 | + releaseHandle<unsigned>(indVecVals[0].idx.arr); |
| 197 | + } |
166 | 198 | } |
167 | 199 | } |
168 | 200 |
|
|
0 commit comments