Commit d68bd71

Speed-up "advanced" indexing operations
This speeds up "advanced" indexing (indexing a tensor by a tensor) on CPU and GPU. There's still a bunch of work to do, including speeding up indexing by a byte (boolean) mask and speeding up the derivative calculation for advanced indexing.

Here are some speed comparisons against indexing on master, using a little [benchmark script](https://gist.github.com/colesbury/c369db72aad594e5e032c8fda557d909) with 16 OpenMP threads and on a P100:

| Test case             | CPU (old vs. new)     | CUDA (old vs. new)   |
|-----------------------|-----------------------|----------------------|
| 1024x1024 -> 512x1024 | 225 us vs. **57 us**  | 297 us vs. **47 us** |
| 1024x1024 -> 1024x512 | 208 us vs. **153 us** | 335 us vs. **54 us** |
| 50x50 -> 20000x50     | 617 us vs. **77 us**  | 239 us vs. **54 us** |
| 50x50 -> 50x20000     | 575 us vs. **236 us** | 262 us vs. **58 us** |
| 2x5x10 -> 10          | 65 us vs. **18 us**   | 612 us vs. **93 us** |
1 parent 4574ea3 commit d68bd71

19 files changed: +555 -96 lines
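Before the per-file diffs, here is a minimal sketch of the operation the benchmark table above measures: "advanced" indexing gathers elements of one tensor using another tensor of indices. The factory calls below are assumptions about the ATen C++ API of the time and are not part of this commit; only `Tensor::index` comes from the changed headers.

```cpp
#include <ATen/ATen.h>

// Roughly the "1024x1024 -> 512x1024" benchmark case: gather 512 rows of a
// 1024x1024 matrix using a LongTensor of row indices (src[rows] in Python).
at::Tensor gather_rows() {
  at::Tensor src  = at::randn({1024, 1024});
  at::Tensor rows = at::randint(1024, {512}, at::kLong);  // assumed factory signature
  return src.index({rows});  // the native `index` op this commit rewrites
}
```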

aten/src/ATen/core/Tensor.h

Lines changed: 2 additions & 2 deletions
@@ -352,8 +352,8 @@ class CAFFE2_API Tensor {
   Tensor irfft(int64_t signal_ndim, bool normalized=false, bool onesided=true, IntList signal_sizes={}) const;
   Tensor index(TensorList indices) const;
   Tensor & index_copy_(int64_t dim, const Tensor & index, const Tensor & source);
-  Tensor index_put(TensorList indices, const Tensor & values) const;
-  Tensor & index_put_(TensorList indices, const Tensor & values);
+  Tensor index_put(TensorList indices, const Tensor & values, bool accumulate=false) const;
+  Tensor & index_put_(TensorList indices, const Tensor & values, bool accumulate=false);
   Tensor inverse() const;
   Tensor isclose(const Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) const;
   bool is_distributed() const;
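The new `accumulate` flag controls whether `index_put_` overwrites the indexed locations or adds into them. A minimal sketch of the difference; the factory calls are assumptions about the contemporaneous ATen API, not code from this commit:

```cpp
#include <ATen/ATen.h>

void accumulate_demo() {
  at::Tensor t    = at::zeros({5});
  at::Tensor idx  = at::zeros({3}, at::kLong);  // all three indices point at element 0
  at::Tensor vals = at::ones({3});

  t.index_put_({idx}, vals);                       // accumulate=false (default): a write wins, t[0] == 1
  t.index_put_({idx}, vals, /*accumulate=*/true);  // accumulate=true: all three values are added, t[0] += 3
}
```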

aten/src/ATen/core/TensorMethods.h

Lines changed: 4 additions & 4 deletions
@@ -314,11 +314,11 @@ inline Tensor Tensor::index(TensorList indices) const {
 inline Tensor & Tensor::index_copy_(int64_t dim, const Tensor & index, const Tensor & source) {
   return type().index_copy_(*this, dim, index, source);
 }
-inline Tensor Tensor::index_put(TensorList indices, const Tensor & values) const {
-  return type().index_put(*this, indices, values);
+inline Tensor Tensor::index_put(TensorList indices, const Tensor & values, bool accumulate) const {
+  return type().index_put(*this, indices, values, accumulate);
 }
-inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values) {
-  return type().index_put_(*this, indices, values);
+inline Tensor & Tensor::index_put_(TensorList indices, const Tensor & values, bool accumulate) {
+  return type().index_put_(*this, indices, values, accumulate);
 }
 inline Tensor Tensor::inverse() const {
   return type().inverse(*this);

aten/src/ATen/core/Type.h

Lines changed: 2 additions & 2 deletions
@@ -257,8 +257,8 @@ struct CAFFE2_API Type {
   virtual Tensor irfft(const Tensor & self, int64_t signal_ndim, bool normalized, bool onesided, IntList signal_sizes) const = 0;
   virtual Tensor index(const Tensor & self, TensorList indices) const = 0;
   virtual Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) const = 0;
-  virtual Tensor index_put(const Tensor & self, TensorList indices, const Tensor & values) const = 0;
-  virtual Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & values) const = 0;
+  virtual Tensor index_put(const Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
+  virtual Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & values, bool accumulate) const = 0;
   virtual Tensor inverse(const Tensor & self) const = 0;
   virtual Tensor isclose(const Tensor & self, const Tensor & other, double rtol, double atol, bool equal_nan) const = 0;
   virtual bool is_distributed(const Tensor & self) const = 0;

aten/src/ATen/native/Indexing.cpp

Lines changed: 184 additions & 21 deletions
@@ -21,10 +21,12 @@
 // adjacent (e.g. x[[0, 1], :, [2, 3]]). In this case, self and the index
 // tensors are transposed to the front: x.transpose(1, 2)[[0, 1], [2, 3]]
 
+#include <ATen/native/Indexing.h>
 
-#include "ATen/ATen.h"
-#include "ATen/NativeFunctions.h"
-#include "ATen/ExpandUtils.h"
+#include <ATen/ATen.h>
+#include <ATen/NativeFunctions.h>
+#include <ATen/ExpandUtils.h>
+#include <ATen/native/TensorIterator.h>
 
 #include <algorithm>
 #include <functional>
@@ -33,6 +35,9 @@
 
 namespace at { namespace native {
 
+DEFINE_DISPATCH(index_stub);
+DEFINE_DISPATCH(index_put_stub);
+
 [[noreturn]]
 static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) {
   std::stringstream ss;
@@ -226,34 +231,192 @@ static std::tuple<Tensor, Tensor> makeLinearIndex(Tensor self, TensorList orig)
   return std::make_tuple(self, linearIndex);
 }
 
-Tensor index(const Tensor & self, TensorList indices) {
-  AT_CHECK(indices.size() <= (size_t)self.dim(),
-           "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
+static bool all_strides_match(TensorList tensors) {
+  AT_ASSERT(tensors.size() >= 1);
+  auto strides = tensors[0].strides();
+  for (auto& tensor : tensors.slice(1)) {
+    if (!strides.equals(tensor.strides())) {
+      return false;
+    }
+  }
+  return true;
+}
+
+static std::string shapes_as_str(TensorList tensors) {
+  std::ostringstream os;
+  bool first = true;
+  for (auto& tensor : tensors) {
+    if (tensor.defined()) {
+      if (!first) {
+        os << ", ";
+      }
+      os << tensor.sizes();
+      first = false;
+    }
+  }
+  return os.str();
+}
+
+struct AdvancedIndex {
+  AdvancedIndex(const Tensor& src, TensorList indices);
+
+  Tensor src;
+  std::vector<Tensor> indices;
+  DimVector indexed_sizes;
+  DimVector indexed_strides;
+  int64_t dims_before;
+  int64_t dims_after;
+};
+
+static Tensor restride_src(const Tensor& src, int64_t dims_before, int64_t dims_indexed,
+                           IntList replacement_shape) {
+  auto shape = DimVector(src.sizes());
+  auto strides = DimVector(src.strides());
+  int end = dims_before + dims_indexed;
+  shape.erase(shape.begin() + dims_before, shape.begin() + end);
+  strides.erase(strides.begin() + dims_before, strides.begin() + end);
+  shape.insert(shape.begin() + dims_before, replacement_shape.begin(), replacement_shape.end());
+  strides.insert(strides.begin() + dims_before, replacement_shape.size(), 0);
+  return src.as_strided(shape, strides);
+}
+
+static Tensor reshape_indexer(const Tensor& index, int64_t dims_before, int64_t dims_after) {
+  auto orig_shape = index.sizes();
+  auto shape = DimVector();
+  shape.append(dims_before, 1);
+  shape.append(orig_shape.begin(), orig_shape.end());
+  shape.append(dims_after, 1);
+  return index.reshape(shape);
+}
+
+AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
+{
+  int64_t element_size_bytes = src.type().elementSizeInBytes();
+  int dims_before = 0, dims_after = 0, dims_indexed = 0;
+  IntList replacement_shape;
+  for (size_t dim = 0; dim < indices_list.size(); dim++) {
+    if (!indices_list[dim].defined()) {
+      if (dims_indexed == 0) {
+        dims_before++;
+      } else {
+        dims_after++;
+      }
+    } else {
+      dims_indexed++;
+      replacement_shape = indices_list[dim].sizes();
+      indexed_sizes.push_back(src.size(dim));
+      indexed_strides.push_back(src.stride(dim) * element_size_bytes);
+    }
+  }
+
+  this->dims_before = dims_before;
+  this->dims_after = dims_after;
+  this->src = restride_src(src, dims_before, dims_indexed, replacement_shape);
+
+  for (auto& index : indices_list) {
+    if (index.defined()) {
+      indices.push_back(reshape_indexer(index, dims_before, dims_after));
+    }
+  }
 
-  Tensor src, linearIndex;
-  std::tie(src, linearIndex) = makeLinearIndex(self, indices);
-  return src.take(linearIndex);
+  // For CUDA tensors, force all index tensors to have the same striding to
+  // simplify the CUDA kernel.
+  if (indices.size() >= 2 && this->src.type().device_type() == kCUDA) {
+    if (!all_strides_match(indices)) {
+      for (size_t i = 0; i < indices.size(); i++) {
+        indices[i] = indices[i].contiguous();
+      }
+    }
+  }
 }
 
-Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value) {
+static AdvancedIndex make_info(Tensor self, TensorList orig) {
+  checkIndexTensorTypes(orig);
+  // first expand ByteTensor (boolean masks) into 1 or more LongTensors
+  auto indices = expandByteTensors(self, orig);
+  // next broadcast all index tensors together
+  try {
+    indices = expand_outplace(indices);
+  } catch (std::exception& e) {
+    AT_ERROR("shape mismatch: indexing tensors could not be broadcast together"
+             " with shapes ", shapes_as_str(indices));
+  }
+  // add missing null Tensors so that it matches self.dim()
+  while (indices.size() < (size_t)self.dim()) {
+    indices.emplace_back();
+  }
+  // if the non-null indices are not all adjacent, transpose self and indices
+  // together so that they're adjacent at the front
+  if (!hasContiguousSubspace(indices)) {
+    std::tie(self, indices) = transposeToFront(self, indices);
+  }
+  return AdvancedIndex(self, indices);
+}
+
+static Tensor make_bogus_tensor(const Tensor& self, const AdvancedIndex& info) {
+  auto shape = DimVector(info.src.sizes());
+  auto strides = DimVector(shape.size(), 0);
+  strides[strides.size() - 1] = 1;
+  for (int dim = strides.size() - 2; dim >= 0; dim--) {
+    strides[dim] = strides[dim + 1] * shape[dim + 1];
+  }
+  return info.src.as_strided(shape, strides);
+}
+
+static std::unique_ptr<TensorIterator> make_index_iterator(const AdvancedIndex& info) {
+  auto builder = TensorIterator::Builder();
+  builder.dont_compute_common_dtype();
+  builder.add_output(Tensor(), &info.src.type());
+  builder.add_input(info.src);
+  for (auto& index : info.indices) {
+    builder.add_input(index);
+  }
+  return builder.build();
+}
+
+static std::unique_ptr<TensorIterator> make_index_put_iterator(const AdvancedIndex& info, const Tensor& value) {
+  if (!is_expandable_to(value.sizes(), info.src.sizes())) {
+    AT_ERROR("shape mismatch: value tensor of shape ", value.sizes(),
+             " cannot be broadcast to indexing result of shape ", info.src.sizes());
+  }
+  auto builder = TensorIterator::Builder();
+  builder.dont_compute_common_dtype();
+  builder.dont_resize_outputs();
+  builder.add_output(info.src);
+  builder.add_input(value, &info.src.type());
+  for (auto& index : info.indices) {
+    builder.add_input(index);
+  }
+  return builder.build();
+}
+
+Tensor index(const Tensor & self, TensorList indices) {
   AT_CHECK(indices.size() <= (size_t)self.dim(),
            "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
 
-  Tensor src, linearIndex, expandedValue;
-  std::tie(src, linearIndex) = makeLinearIndex(self, indices);
-  std::tie(expandedValue) = expand_inplace(linearIndex, value);
-  Tensor dst = src.clone();
-  return dst.put_(linearIndex, expandedValue);
+  auto info = make_info(self, indices);
+  auto iter = make_index_iterator(info);
+  index_stub(iter->device_type(), *iter, info.indexed_sizes, info.indexed_strides);
+  return iter->output();
 }
 
-Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value) {
+Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value, bool accumulate) {
+  return self.clone().index_put_(indices, value, accumulate);
+}
+
+Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value, bool accumulate) {
   AT_CHECK(indices.size() <= (size_t)self.dim(),
            "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
-
-  Tensor src, linearIndex, expandedValue;
-  std::tie(src, linearIndex) = makeLinearIndex(self, indices);
-  std::tie(expandedValue) = expand_inplace(linearIndex, value);
-  return src.put_(linearIndex, expandedValue);
+  if (accumulate && self.type().device_type() == kCUDA) {
+    Tensor src, linearIndex, expandedValue;
+    std::tie(src, linearIndex) = makeLinearIndex(self, indices);
+    std::tie(expandedValue) = expand_inplace(linearIndex, value);
+    return src.put_(linearIndex, expandedValue, true);
+  }
+  auto info = make_info(self, indices);
+  auto iter = make_index_put_iterator(info, value);
+  index_put_stub(iter->device_type(), *iter, info.indexed_sizes, info.indexed_strides, accumulate);
+  return self;
 }
 
 Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) {
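The core trick in `AdvancedIndex` above is to give the indexed dimensions stride 0 in the restrided `src`, so the iterator walks an output-shaped iteration space, and to recover the real byte offset at runtime from the index values via `indexed_strides`. Below is a standalone sketch of that offset arithmetic for a 2-D column gather, in plain C++ with illustrative names; it is not code from the commit, only a model of what the kernels compute:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Emulate src[:, idx] for a 4x5 row-major float matrix and idx = {4, 0, 2}.
  std::vector<float> src(4 * 5);
  for (size_t i = 0; i < src.size(); i++) src[i] = static_cast<float>(i);
  const int64_t idx[3] = {4, 0, 2};

  // Analogue of indexed_strides: byte stride of the indexed dimension (dim 1).
  const int64_t indexed_stride = 1 * sizeof(float);

  std::vector<float> out(4 * 3);
  for (int64_t row = 0; row < 4; row++) {
    // Stride-0 view: the same row base pointer is reused for every output column;
    // only the index value decides how far to move within the row.
    const char* row_base = reinterpret_cast<const char*>(src.data() + row * 5);
    for (int64_t col = 0; col < 3; col++) {
      out[row * 3 + col] =
          *reinterpret_cast<const float*>(row_base + idx[col] * indexed_stride);
    }
  }
  std::printf("out row 0: %g %g %g\n", out[0], out[1], out[2]);  // prints 4 0 2
  return 0;
}
```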

aten/src/ATen/native/Indexing.h

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+#pragma once
+
+// Indexing tensors by tensors
+
+#include <ATen/ATen.h>
+#include <ATen/native/DispatchStub.h>
+
+namespace at {
+struct TensorIterator;
+}
+
+namespace at { namespace native {
+
+using index_fn = void(*)(TensorIterator &, IntList indexed_sizes, IntList indexed_strides);
+using index_put_fn = void(*)(TensorIterator &, IntList indexed_sizes, IntList indexed_strides, bool accumulate);
+
+DECLARE_DISPATCH(index_fn, index_stub);
+DECLARE_DISPATCH(index_put_fn, index_put_stub);
+
+}} // namespace at::native
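These declarations are filled in by per-backend kernel files, presumably elsewhere in the commit's 19 changed files but not shown in this excerpt. A hedged sketch of what the CPU-side wiring would look like; the file path and kernel body are placeholders, not the commit's actual kernel:

```cpp
// e.g. aten/src/ATen/native/cpu/IndexKernel.cpp (illustrative path)
#include <ATen/native/Indexing.h>
#include <ATen/native/TensorIterator.h>

namespace at { namespace native {
namespace {

void index_kernel(TensorIterator& iter, IntList indexed_sizes, IntList indexed_strides) {
  // Placeholder body: the real kernel walks the iterator, turning each set of
  // index values into a byte offset into src via indexed_strides, then copies.
}

}  // anonymous namespace

REGISTER_DISPATCH(index_stub, &index_kernel);

}}  // namespace at::native
```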
