Commit 174e1ba

Natalia Gimelshein authored and facebook-github-bot committed
Small fixes to improve TensorIterator overhead for the common case of inputs and outputs of the same type (#27457)
Summary:

1) Short-circuits the common-type computation and type-promotion logic for the common case where the operands and the result all have the same type (see the sketch below).
2) Improves the performance of the memory-overlap check by returning MemOverlapStatus::FULL when the two tensors are the same, and skips the call from TensorIterator altogether when input and output are the same tensor.
3) Stores strides in a StrideVector typedef (SmallVector<int64_t, 6>) instead of the inline-size-5 DimVector, so the vector need not reallocate in the common binary-op case: the `strides` vector must hold at least 2*num_tensors elements, which is 6 for an operation with two inputs and one output.
4) If `offset` is 0 (the common, non-broadcasting case), does not fill the `strides` vector with zeros, because all of its values are subsequently overwritten.

Combined, these changes reduce the overhead of a simple in-place operation from 1.02 us to 0.74 us.

Pull Request resolved: #27457

Test Plan: should be covered by existing tests

Differential Revision: D17784532

Pulled By: ngimel

fbshipit-source-id: e6a8ee58be5de14461bdbc2e2b0b6d16a96c309f
1 parent 3ac4267 · commit 174e1ba
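To make item (1) of the summary concrete, here is a minimal standalone sketch of the same-type scan. This is hedged: `Operand` and `scan_common_type` are hypothetical stand-ins for ATen's OperandInfo and compute_common_type_; the committed implementation is in the TensorIterator.cpp hunks below.

    #include <tuple>
    #include <vector>

    enum class ScalarType { Undefined, Int, Float, Double };

    // Hypothetical stand-in for OperandInfo: only what the scan needs.
    struct Operand {
      bool defined;
      int dim;
      ScalarType type;
    };

    // One cheap pass over the operands: if every defined, non-scalar operand
    // already has the same dtype, report it with all_same_type == true so the
    // caller can skip the full result_type() promotion machinery.
    std::tuple<ScalarType, bool> scan_common_type(const std::vector<Operand>& ops) {
      ScalarType common = ScalarType::Undefined;
      for (const auto& op : ops) {
        if (!op.defined) continue;
        if (op.dim == 0) {
          return std::make_tuple(ScalarType::Undefined, false);  // scalars: fall back
        }
        if (common == ScalarType::Undefined) common = op.type;
        if (op.type != common) {
          return std::make_tuple(ScalarType::Undefined, false);  // mixed dtypes
        }
      }
      return std::make_tuple(common, common != ScalarType::Undefined);
    }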

3 files changed: 56 additions, 23 deletions

aten/src/ATen/MemoryOverlap.cpp (1 addition, 0 deletions)

@@ -39,6 +39,7 @@ MemOverlapStatus get_overlap_status(const Tensor& a, const Tensor& b) {
 }
 
 MemOverlapStatus get_overlap_status(TensorImpl* a, TensorImpl* b) {
+  if (a == b) return MemOverlapStatus::FULL;
   if (!a->is_contiguous() || !b->is_contiguous()) {
     return MemOverlapStatus::TOO_HARD;
   }
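The second half of item (2), skipping the overlap check inside TensorIterator when input and output are the same tensor, is not part of the hunk above. A hedged sketch of what such a caller-side guard looks like follows; the loop shape is illustrative rather than the committed code, while assert_no_partial_overlap, unsafeGetTensorImpl, noutputs(), ntensors(), and operands_ are real ATen/TensorIterator names:

    // Inside TensorIterator setup: compare TensorImpl pointers first, so an
    // in-place op like x.add_(1) never pays for the overlap analysis.
    for (int i = 0; i < noutputs(); i++) {
      auto& output = operands_[i].tensor;
      for (int j = noutputs(); j < ntensors(); j++) {
        auto& input = operands_[j].tensor;
        if (output.unsafeGetTensorImpl() != input.unsafeGetTensorImpl()) {
          assert_no_partial_overlap(output, input);
        }
      }
    }

The two halves are complementary: even when the call is not skipped, the new `a == b` early return reports FULL without walking the contiguity checks, and assert_no_partial_overlap, as the name suggests, only errors on a PARTIAL overlap.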

aten/src/ATen/native/TensorIterator.cpp (46 additions, 16 deletions)

@@ -12,6 +12,7 @@ using DimMask = TensorIterator::DimMask;
 using PtrVector = TensorIterator::PtrVector;
 using loop_t = TensorIterator::loop_t;
 using loop2d_t = TensorIterator::loop2d_t;
+using StrideVector = TensorIterator::StrideVector;
 
 void TensorIterator::reorder_dimensions() {
   // Sort the dimensions based on strides in ascending order with reduced dims
@@ -86,19 +87,44 @@ Device compute_device(at::ArrayRef<OperandInfo> operands) {
   return kCPU;
 }
 
-static std::tuple<Device, ScalarType> compute_common_type_(at::ArrayRef<OperandInfo> operands) {
+static std::tuple<Device, ScalarType, bool> compute_common_type_(at::ArrayRef<OperandInfo> operands) {
   // See [Result type computation] in TensorIterator.h
   auto device = compute_device(operands);
+  auto common_type = ScalarType::Undefined;
+  bool all_same_type = true;
+  for (const auto& op: operands){
+    if (!op.tensor.defined()) continue;
+    //don't handle scalars
+    if (op.tensor.dim() > 0){
+      ScalarType current = op.tensor.scalar_type();
+      if (current == ScalarType::Undefined){
+        all_same_type = false;
+        break;
+      }
+      if (common_type == ScalarType::Undefined) common_type = current;
+      if (common_type != current) {
+        all_same_type = false;
+        break;
+      }
+    } else {
+      all_same_type = false;
+      break;
+    }
+  }
+  if (all_same_type) {
+    return std::make_tuple(device, common_type, true);
+  }
+  //TODO refactor so that no tensor copies are done
   std::vector<Tensor> tensors;
   std::transform(std::begin(operands), std::end(operands), std::back_inserter(tensors),
                  [](const OperandInfo& op) { return op.tensor; });
   auto dtype = at::native::result_type(tensors);
-  auto result = std::make_tuple(device, dtype);
+  auto result = std::make_tuple(device, dtype, false);
   TORCH_INTERNAL_ASSERT(dtype != ScalarType::Undefined);
   return result;
 }
 
-std::tuple<Device, ScalarType> TensorIterator::compute_common_type() {
+std::tuple<Device, ScalarType, bool> TensorIterator::compute_common_type() {
   return compute_common_type_(operands_);
 }
 
@@ -199,11 +225,13 @@ void TensorIterator::compute_types() {
       }
     }
 
-    if (!compute_common_dtype_only_for_inputs) {
-      validate_dtype(op, common_dtype, ninputs());
-    }
-    if (!compute_common_dtype_only_for_inputs || !op.is_output) {
-      maybe_promote_common_dtype(op, common_dtype);
+    if (!std::get<2>(common_type)) {
+      if (!compute_common_dtype_only_for_inputs) {
+        validate_dtype(op, common_dtype, ninputs());
+      }
+      if (!compute_common_dtype_only_for_inputs || !op.is_output) {
+        maybe_promote_common_dtype(op, common_dtype);
+      }
     }
 
     if (op.tensor.defined() && op.device != op.tensor.device()) {
@@ -221,8 +249,8 @@ void TensorIterator::compute_types() {
   }
 }
 
-DimVector TensorIterator::compatible_stride(int element_size) const {
-  auto stride = DimVector();
+StrideVector TensorIterator::compatible_stride(int element_size) const {
+  auto stride = StrideVector();
   int64_t next_stride = element_size;
   for (int dim = 0; dim < ndim(); dim++) {
     stride.push_back(next_stride);
@@ -369,9 +397,9 @@ int64_t TensorIterator::numel() const {
   return numel;
 }
 
-DimVector TensorIterator::get_dim_strides(int dim) const {
+StrideVector TensorIterator::get_dim_strides(int dim) const {
   auto dims = ndim();
-  auto inner_strides = DimVector();
+  auto inner_strides = StrideVector();
   for (auto& op : operands_) {
     inner_strides.push_back(dims == 0 ? 0 : op.stride_bytes[dim]);
   }
@@ -478,8 +506,8 @@ void TensorIterator::for_each(loop2d_t loop) {
   }
 }
 
-DimVector TensorIterator::get_strides() const {
-  DimVector strides;
+StrideVector TensorIterator::get_strides() const {
+  StrideVector strides;
   for (int dim = 0; dim < ndim(); dim++) {
     for (int arg = 0; arg < ntensors(); arg++) {
       strides.push_back(operands_[arg].stride_bytes[dim]);
@@ -751,9 +779,11 @@ void TensorIterator::compute_strides() {
     auto original_shape = op.tensor.sizes();
    auto original_stride = op.tensor.strides();
    auto element_size_in_bytes = op.tensor.element_size();
-
-    op.stride_bytes.resize(ndim(), 0);
    auto offset = ndim() - original_shape.size();
+    if (offset > 0)
+      op.stride_bytes.resize(ndim(), 0);
+    else
+      op.stride_bytes.resize(ndim());
    for (size_t i = 0; i < original_shape.size(); i++) {
      if (original_shape[i] == 1) {
        op.stride_bytes[offset + i] = 0;
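Why the inline size of 6 matters for item (3): after coalescing, a typical binary op has three tensors and a 2-D loop, so get_strides() pushes exactly 3 * 2 = 6 entries. A small demonstration sketch, assuming c10/util/SmallVector.h and using placeholder stride values:

    #include <c10/util/SmallVector.h>
    #include <cstdint>

    using StrideVector = c10::SmallVector<int64_t, 6>;

    // ntensors * ndim = 3 * 2 = 6 entries fit exactly in the inline storage,
    // so no heap allocation occurs; an inline capacity of 5 (DimVector) would
    // spill to the heap on the sixth push_back for every such operation.
    StrideVector demo_strides() {
      StrideVector strides;
      for (int dim = 0; dim < 2; ++dim) {
        for (int arg = 0; arg < 3; ++arg) {
          strides.push_back(0);  // placeholder for operands_[arg].stride_bytes[dim]
        }
      }
      return strides;
    }

The compute_strides() hunk above implements item (4) of the summary: when offset == 0, every slot of stride_bytes is written by the loop that follows, so the zero-filling resize overload is only needed in the broadcasting case.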

aten/src/ATen/native/TensorIterator.h (9 additions, 7 deletions)

@@ -52,7 +52,7 @@
 // (See https://github.com/pytorch/pytorch/issues/9515)
 //
 // Note that TensorIterator currently supports type conversions on 0-dim
-// tensors and arithmetic operators. Other type conversions will raise an
+// tensors and arithmetic operators. Other type conversions will raise an
 // exception.
 
 namespace at {

(The first hunk is a whitespace-only change as rendered here.)

@@ -71,6 +71,7 @@ struct DimCounter {
 };
 
 struct CAFFE2_API OperandInfo {
+  using StrideVector = SmallVector<int64_t, 6>;
   OperandInfo() {}
   explicit OperandInfo(const Tensor& t) : tensor(t) {
     if (t.defined()) {
@@ -85,7 +86,7 @@ struct CAFFE2_API OperandInfo {
   }
 
   /// Stride after broadcasting. The stride is in bytes, not number of elements.
-  DimVector stride_bytes;
+  StrideVector stride_bytes;
 
   /// The tensor operand. Note that the strides, data pointer, and
   /// other attributes may differ due to dimension reordering and
@@ -134,6 +135,7 @@ enum class CommonDTypeStrategy : uint8_t {
 struct CAFFE2_API TensorIterator {
   using DimMask = std::bitset<64>;
   using PtrVector = SmallVector<char*, 4>;
+  using StrideVector = SmallVector<int64_t, 6>;
 
   TensorIterator() {}
 
@@ -254,16 +256,16 @@ struct CAFFE2_API TensorIterator {
   /// Create a strides array for a Tensor with shape of this iterator. The
   /// parameter `element_size` specifies the size of Tensor's data type in
   /// bytes (e.g. `4` for `float`)
-  DimVector compatible_stride(int element_size) const;
+  StrideVector compatible_stride(int element_size) const;
 
   /// Inverts the re-ordering done by reorder_dimensions. This can only be
   /// called *before* coalesce_dimensions() is called.
   DimVector invert_perm(IntArrayRef input) const;
 
   /// Helper functions for CPU iteration
-  DimVector get_dim_strides(int dim) const;
-  DimVector get_strides() const;
-  DimVector get_inner_strides() const { return get_dim_strides(0); }
+  StrideVector get_dim_strides(int dim) const;
+  StrideVector get_strides() const;
+  StrideVector get_inner_strides() const { return get_dim_strides(0); }
   PtrVector get_data_ptrs(ArrayRef<char*> base, IntArrayRef counter) const;
   PtrVector get_base_ptrs() const;
 
@@ -328,7 +330,7 @@ struct CAFFE2_API TensorIterator {
   void reorder_dimensions();
   void permute_dimensions(IntArrayRef perm);
   void compute_types();
-  std::tuple<Device, ScalarType> compute_common_type();
+  std::tuple<Device, ScalarType, bool> compute_common_type();
   void allocate_outputs();
 #ifdef BUILD_NAMEDTENSOR
   void compute_names();
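The consuming side of the widened compute_common_type() signature is the std::get<2>(common_type) check in the TensorIterator.cpp hunk earlier. Purely as an illustration of the shape of that check, the same unpacking with C++17 structured bindings (hypothetical; the committed code uses std::get, which does not require C++17):

    auto [device, dtype, all_same_type] = compute_common_type();
    if (!all_same_type) {
      // mixed dtypes: fall through to validate_dtype / maybe_promote_common_dtype
    }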
