|
21 | 21 | // adjacent (e.g. x[[0, 1], :, [2, 3]]). In this case, self and the index |
22 | 22 | // tensors are transposed to the front: x.transpose(1, 2)[[0, 1], [2, 3]] |
23 | 23 |
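To make the rule in the comment above concrete: for `x` of shape (3, 4, 5), the index tensors in `x[[0, 1], :, [2, 3]]` sit on non-adjacent dims 0 and 2, so the broadcast index shape (2) moves to the front and the result has shape (2, 4). The sketch below is illustrative only, not part of the diff; it goes through the `index` entry point changed here and assumes the TensorOptions-style factories (`at::randn`, `at::zeros`) of this ATen revision.

```cpp
#include <ATen/ATen.h>

// Illustrative sketch, not part of the diff: non-adjacent indexed dims.
void non_adjacent_index_example() {
  at::Tensor x = at::randn({3, 4, 5});
  at::Tensor i = at::zeros({2}, at::kLong);      // stands in for [0, 1]
  at::Tensor j = at::zeros({2}, at::kLong);      // stands in for [2, 3]
  // An undefined Tensor in the TensorList means "take the whole dim" (a ':').
  at::Tensor r = x.index({i, at::Tensor(), j});  // x[i, :, j]
  // The indexed dims are non-adjacent, so the index shape leads the result.
  AT_ASSERT(r.sizes().equals({2, 4}));
}
```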
|
| 24 | +#include <ATen/native/Indexing.h> |
24 | 25 |
|
25 | | -#include "ATen/ATen.h" |
26 | | -#include "ATen/NativeFunctions.h" |
27 | | -#include "ATen/ExpandUtils.h" |
| 26 | +#include <ATen/ATen.h> |
| 27 | +#include <ATen/NativeFunctions.h> |
| 28 | +#include <ATen/ExpandUtils.h> |
| 29 | +#include <ATen/native/TensorIterator.h> |
28 | 30 |
|
29 | 31 | #include <algorithm> |
30 | 32 | #include <functional> |
|
33 | 35 |
|
34 | 36 | namespace at { namespace native { |
35 | 37 |
|
| 38 | +DEFINE_DISPATCH(index_stub); |
| 39 | +DEFINE_DISPATCH(index_put_stub); |
| 40 | + |
36 | 41 | [[noreturn]] |
37 | 42 | static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, int64_t maskIdx) { |
38 | 43 | std::stringstream ss; |
@@ -226,34 +231,192 @@ static std::tuple<Tensor, Tensor> makeLinearIndex(Tensor self, TensorList orig) |
226 | 231 | return std::make_tuple(self, linearIndex); |
227 | 232 | } |
228 | 233 |
|
229 | | -Tensor index(const Tensor & self, TensorList indices) { |
230 | | - AT_CHECK(indices.size() <= (size_t)self.dim(), |
231 | | - "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); |
| 234 | +static bool all_strides_match(TensorList tensors) { |
| 235 | + AT_ASSERT(tensors.size() >= 1); |
| 236 | + auto strides = tensors[0].strides(); |
| 237 | + for (auto& tensor : tensors.slice(1)) { |
| 238 | + if (!strides.equals(tensor.strides())) { |
| 239 | + return false; |
| 240 | + } |
| 241 | + } |
| 242 | + return true; |
| 243 | +} |
| 244 | + |
| 245 | +static std::string shapes_as_str(TensorList tensors) { |
| 246 | + std::ostringstream os; |
| 247 | + bool first = true; |
| 248 | + for (auto& tensor : tensors) { |
| 249 | + if (tensor.defined()) { |
| 250 | + if (!first) { |
| 251 | + os << ", "; |
| 252 | + } |
| 253 | + os << tensor.sizes(); |
| 254 | + first = false; |
| 255 | + } |
| 256 | + } |
| 257 | + return os.str(); |
| 258 | +} |
| 259 | + |
| 260 | +struct AdvancedIndex { |
| 261 | + AdvancedIndex(const Tensor& src, TensorList indices); |
| 262 | + |
| 263 | + Tensor src; |
| 264 | + std::vector<Tensor> indices; |
| 265 | + DimVector indexed_sizes; |
| 266 | + DimVector indexed_strides; |
| 267 | + int64_t dims_before; |
| 268 | + int64_t dims_after; |
| 269 | +}; |
| 270 | + |
| 271 | +static Tensor restride_src(const Tensor& src, int64_t dims_before, int64_t dims_indexed, |
| 272 | + IntList replacement_shape) { |
| 273 | + auto shape = DimVector(src.sizes()); |
| 274 | + auto strides = DimVector(src.strides()); |
| 275 | + int end = dims_before + dims_indexed; |
| 276 | + shape.erase(shape.begin() + dims_before, shape.begin() + end); |
| 277 | + strides.erase(strides.begin() + dims_before, strides.begin() + end); |
| 278 | + shape.insert(shape.begin() + dims_before, replacement_shape.begin(), replacement_shape.end()); |
| 279 | + strides.insert(strides.begin() + dims_before, replacement_shape.size(), 0); |
| 280 | + return src.as_strided(shape, strides); |
| 281 | +} |
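A worked example of `restride_src` may help: the indexed dims of `src` are erased and replaced by stride-0 dims matching the broadcast index shape, so stepping over the index shape does not by itself advance through `src`'s storage (the real offsets come from `indexed_strides` inside the kernel). A minimal sketch under that reading, using only the public `as_strided` call:

```cpp
#include <ATen/ATen.h>

// Sketch only: a contiguous (5, 6) tensor whose dim 0 is indexed by a (2, 3)
// index, i.e. dims_before = 0, dims_indexed = 1, replacement_shape = (2, 3).
void restride_src_example() {
  at::Tensor src = at::randn({5, 6});            // strides (6, 1)
  // restride_src would produce the same view as:
  at::Tensor restrided = src.as_strided({2, 3, 6}, {0, 0, 1});
  AT_ASSERT(restrided.sizes().equals({2, 3, 6}));
}
```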
| 282 | + |
| 283 | +static Tensor reshape_indexer(const Tensor& index, int64_t dims_before, int64_t dims_after) { |
| 284 | + auto orig_shape = index.sizes(); |
| 285 | + auto shape = DimVector(); |
| 286 | + shape.append(dims_before, 1); |
| 287 | + shape.append(orig_shape.begin(), orig_shape.end()); |
| 288 | + shape.append(dims_after, 1); |
| 289 | + return index.reshape(shape); |
| 290 | +} |
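`reshape_indexer` is the counterpart on the index side: it pads the index with size-1 dims so it lines up with the restrided `src` and broadcasts inside the TensorIterator. A sketch, illustrative only:

```cpp
#include <ATen/ATen.h>

// Sketch only: a (2, 3) index with one un-indexed dim before it and two after
// it (dims_before = 1, dims_after = 2) is reshaped to (1, 2, 3, 1, 1).
void reshape_indexer_example() {
  at::Tensor index = at::zeros({2, 3}, at::kLong);
  at::Tensor reshaped = index.reshape({1, 2, 3, 1, 1});
  AT_ASSERT(reshaped.dim() == 5);
}
```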
| 291 | + |
| 292 | +AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list) |
| 293 | +{ |
| 294 | + int64_t element_size_bytes = src.type().elementSizeInBytes(); |
| 295 | + int dims_before = 0, dims_after = 0, dims_indexed = 0; |
| 296 | + IntList replacement_shape; |
| 297 | + for (size_t dim = 0; dim < indices_list.size(); dim++) { |
| 298 | + if (!indices_list[dim].defined()) { |
| 299 | + if (dims_indexed == 0) { |
| 300 | + dims_before++; |
| 301 | + } else { |
| 302 | + dims_after++; |
| 303 | + } |
| 304 | + } else { |
| 305 | + dims_indexed++; |
| 306 | + replacement_shape = indices_list[dim].sizes(); |
| 307 | + indexed_sizes.push_back(src.size(dim)); |
| 308 | + indexed_strides.push_back(src.stride(dim) * element_size_bytes); |
| 309 | + } |
| 310 | + } |
| 311 | + |
| 312 | + this->dims_before = dims_before; |
| 313 | + this->dims_after = dims_after; |
| 314 | + this->src = restride_src(src, dims_before, dims_indexed, replacement_shape); |
| 315 | + |
| 316 | + for (auto& index : indices_list) { |
| 317 | + if (index.defined()) { |
| 318 | + indices.push_back(reshape_indexer(index, dims_before, dims_after)); |
| 319 | + } |
| 320 | + } |
232 | 321 |
|
233 | | - Tensor src, linearIndex; |
234 | | - std::tie(src, linearIndex) = makeLinearIndex(self, indices); |
235 | | - return src.take(linearIndex); |
| 322 | + // For CUDA tensors, force all index tensors to have the same striding to |
| 323 | + // simplify the CUDA kernel. |
| 324 | + if (indices.size() >= 2 && this->src.type().device_type() == kCUDA) { |
| 325 | + if (!all_strides_match(indices)) { |
| 326 | + for (size_t i = 0; i < indices.size(); i++) { |
| 327 | + indices[i] = indices[i].contiguous(); |
| 328 | + } |
| 329 | + } |
| 330 | + } |
236 | 331 | } |
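To ground the bookkeeping in this constructor: for `x` of shape (5, 6, 7) indexed on its middle dim by a (2, 3) index, `dims_before = 1`, `dims_after = 1`, `indexed_sizes = {6}`, `indexed_strides = {7 * element_size}`, the restrided `src` has shape (5, 2, 3, 7) with zero strides in the two index dims, and the reshaped index has shape (1, 2, 3, 1). The sketch below checks only the externally visible consequence (the result shape) through the public `index` call; it is an illustration, not part of the diff.

```cpp
#include <ATen/ATen.h>

// Sketch only: a single indexed dim is replaced in place by the broadcast
// index shape, so x[:, idx] with idx of shape (2, 3) yields (5, 2, 3, 7).
void advanced_index_shape_example() {
  at::Tensor x = at::randn({5, 6, 7});
  at::Tensor idx = at::zeros({2, 3}, at::kLong);
  at::Tensor r = x.index({at::Tensor(), idx});   // x[:, idx]
  AT_ASSERT(r.sizes().equals({5, 2, 3, 7}));
}
```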
237 | 332 |
|
238 | | -Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value) { |
| 333 | +static AdvancedIndex make_info(Tensor self, TensorList orig) { |
| 334 | + checkIndexTensorTypes(orig); |
| 335 | + // first expand ByteTensor (boolean mask) indices into 1 or more LongTensors |

| 336 | + auto indices = expandByteTensors(self, orig); |
| 337 | + // next broadcast all index tensors together |
| 338 | + try { |
| 339 | + indices = expand_outplace(indices); |
| 340 | + } catch (std::exception& e) { |
| 341 | + AT_ERROR("shape mismatch: indexing tensors could not be broadcast together" |
| 342 | + " with shapes ", shapes_as_str(indices)); |
| 343 | + } |
| 344 | + // add missing null Tensors so that it matches self.dim() |
| 345 | + while (indices.size() < (size_t)self.dim()) { |
| 346 | + indices.emplace_back(); |
| 347 | + } |
| 348 | + // if the non-null indices are not all adjacent, transpose self and indices |
| 349 | + // together so that they're adjacent at the front |
| 350 | + if (!hasContiguousSubspace(indices)) { |
| 351 | + std::tie(self, indices) = transposeToFront(self, indices); |
| 352 | + } |
| 353 | + return AdvancedIndex(self, indices); |
| 354 | +} |
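One step in `make_info` that is easy to miss is the ByteTensor expansion: a boolean mask index is converted into LongTensor indices over its nonzero positions before broadcasting. A hedged sketch of the user-visible behaviour (a comparison op returning a ByteTensor is assumed for this ATen revision):

```cpp
#include <ATen/ATen.h>

// Sketch only: a ByteTensor mask over dim 0 selects the rows where the mask
// is nonzero, so the result has shape (k, 4) with k = number of hits.
void byte_mask_example() {
  at::Tensor x = at::randn({3, 4});
  at::Tensor mask = x.select(1, 0).gt(0);  // ByteTensor of shape (3)
  at::Tensor r = x.index({mask});
  AT_ASSERT(r.dim() == 2 && r.size(1) == 4);
}
```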
| 355 | + |
| 356 | +static Tensor make_bogus_tensor(const Tensor& self, const AdvancedIndex& info) { |
| 357 | + auto shape = DimVector(info.src.sizes()); |
| 358 | + auto strides = DimVector(shape.size(), 0); |
| 359 | + strides[strides.size() - 1] = 1; |
| 360 | + for (int dim = strides.size() - 2; dim >= 0; dim--) { |
| 361 | + strides[dim] = strides[dim + 1] * shape[dim + 1]; |
| 362 | + } |
| 363 | + return info.src.as_strided(shape, strides); |
| 364 | +} |
| 365 | + |
| 366 | +static std::unique_ptr<TensorIterator> make_index_iterator(const AdvancedIndex& info) { |
| 367 | + auto builder = TensorIterator::Builder(); |
| 368 | + builder.dont_compute_common_dtype(); |
| 369 | + builder.add_output(Tensor(), &info.src.type()); |
| 370 | + builder.add_input(info.src); |
| 371 | + for (auto& index : info.indices) { |
| 372 | + builder.add_input(index); |
| 373 | + } |
| 374 | + return builder.build(); |
| 375 | +} |
| 376 | + |
| 377 | +static std::unique_ptr<TensorIterator> make_index_put_iterator(const AdvancedIndex& info, const Tensor& value) { |
| 378 | + if (!is_expandable_to(value.sizes(), info.src.sizes())) { |
| 379 | + AT_ERROR("shape mismatch: value tensor of shape ", value.sizes(), |
| 380 | + " cannot be broadcast to indexing result of shape ", info.src.sizes()); |
| 381 | + } |
| 382 | + auto builder = TensorIterator::Builder(); |
| 383 | + builder.dont_compute_common_dtype(); |
| 384 | + builder.dont_resize_outputs(); |
| 385 | + builder.add_output(info.src); |
| 386 | + builder.add_input(value, &info.src.type()); |
| 387 | + for (auto& index : info.indices) { |
| 388 | + builder.add_input(index); |
| 389 | + } |
| 390 | + return builder.build(); |
| 391 | +} |
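The broadcast check in `make_index_put_iterator` is exactly what callers of `index_put_` observe: the value must be expandable to the indexing result's shape. Illustrative sketch only:

```cpp
#include <ATen/ATen.h>

// Sketch only: indexing dim 0 of a (5, 6) tensor with a (2,) index gives a
// (2, 6) result, so a (6,) value broadcasts while a (3,) value is rejected.
void index_put_broadcast_example() {
  at::Tensor x = at::zeros({5, 6});
  at::Tensor idx = at::zeros({2}, at::kLong);
  x.index_put_({idx}, at::ones({6}), /*accumulate=*/false);   // ok
  // x.index_put_({idx}, at::ones({3}), false);  // throws: shape mismatch
}
```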
| 392 | + |
| 393 | +Tensor index(const Tensor & self, TensorList indices) { |
239 | 394 | AT_CHECK(indices.size() <= (size_t)self.dim(), |
240 | 395 | "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); |
241 | 396 |
|
242 | | - Tensor src, linearIndex, expandedValue; |
243 | | - std::tie(src, linearIndex) = makeLinearIndex(self, indices); |
244 | | - std::tie(expandedValue) = expand_inplace(linearIndex, value); |
245 | | - Tensor dst = src.clone(); |
246 | | - return dst.put_(linearIndex, expandedValue); |
| 397 | + auto info = make_info(self, indices); |
| 398 | + auto iter = make_index_iterator(info); |
| 399 | + index_stub(iter->device_type(), *iter, info.indexed_sizes, info.indexed_strides); |
| 400 | + return iter->output(); |
247 | 401 | } |
248 | 402 |
|
249 | | -Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value) { |
| 403 | +Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value, bool accumulate) { |
| 404 | + return self.clone().index_put_(indices, value, accumulate); |
| 405 | +} |
| 406 | + |
| 407 | +Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value, bool accumulate) { |
250 | 408 | AT_CHECK(indices.size() <= (size_t)self.dim(), |
251 | 409 | "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")"); |
252 | | - |
253 | | - Tensor src, linearIndex, expandedValue; |
254 | | - std::tie(src, linearIndex) = makeLinearIndex(self, indices); |
255 | | - std::tie(expandedValue) = expand_inplace(linearIndex, value); |
256 | | - return src.put_(linearIndex, expandedValue); |
| 410 | + if (accumulate && self.type().device_type() == kCUDA) { |
| 411 | + Tensor src, linearIndex, expandedValue; |
| 412 | + std::tie(src, linearIndex) = makeLinearIndex(self, indices); |
| 413 | + std::tie(expandedValue) = expand_inplace(linearIndex, value); |
| 414 | + return src.put_(linearIndex, expandedValue, true); |
| 415 | + } |
| 416 | + auto info = make_info(self, indices); |
| 417 | + auto iter = make_index_put_iterator(info, value); |
| 418 | + index_put_stub(iter->device_type(), *iter, info.indexed_sizes, info.indexed_strides, accumulate); |
| 419 | + return self; |
257 | 420 | } |
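The `accumulate` flag is what separates plain assignment from scatter-add: with repeated indices, `accumulate=false` keeps one (unspecified) write per location, while `accumulate=true` sums every contribution, which is the case the code above still routes through the linear-index `put_` path on CUDA. A sketch of the semantics, for illustration:

```cpp
#include <ATen/ATen.h>

// Sketch only: both index entries point at x[0], so with accumulate=true the
// two ones are summed and x becomes [2, 0, 0].
void accumulate_example() {
  at::Tensor x = at::zeros({3});
  at::Tensor idx = at::zeros({2}, at::kLong);   // both entries index x[0]
  x.index_put_({idx}, at::ones({2}), /*accumulate=*/true);
}
```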
258 | 421 |
|
259 | 422 | Tensor & index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source) { |
|