Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aten/src/ATen/MemoryOverlap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ MemOverlapStatus get_overlap_status(const Tensor& a, const Tensor& b) {
}

MemOverlapStatus get_overlap_status(TensorImpl* a, TensorImpl* b) {
if (a == b) return MemOverlapStatus::FULL;
if (!a->is_contiguous() || !b->is_contiguous()) {
return MemOverlapStatus::TOO_HARD;
}
Expand Down
62 changes: 46 additions & 16 deletions aten/src/ATen/native/TensorIterator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using DimMask = TensorIterator::DimMask;
using PtrVector = TensorIterator::PtrVector;
using loop_t = TensorIterator::loop_t;
using loop2d_t = TensorIterator::loop2d_t;
using StrideVector = TensorIterator::StrideVector;

void TensorIterator::reorder_dimensions() {
// Sort the dimensions based on strides in ascending order with reduced dims
Expand Down Expand Up @@ -86,19 +87,44 @@ Device compute_device(at::ArrayRef<OperandInfo> operands) {
return kCPU;
}

static std::tuple<Device, ScalarType> compute_common_type_(at::ArrayRef<OperandInfo> operands) {
static std::tuple<Device, ScalarType, bool> compute_common_type_(at::ArrayRef<OperandInfo> operands) {
// See [Result type computation] in TensorIterator.h
auto device = compute_device(operands);
auto common_type = ScalarType::Undefined;
bool all_same_type = true;
for (const auto& op: operands){
if (!op.tensor.defined()) continue;
//don't handle scalars
if (op.tensor.dim() > 0){
ScalarType current = op.tensor.scalar_type();
if (current == ScalarType::Undefined){
all_same_type = false;
break;
}
if (common_type == ScalarType::Undefined) common_type = current;
if (common_type != current) {
all_same_type = false;
break;
}
} else {
all_same_type = false;
break;
}
}
if (all_same_type) {
return std::make_tuple(device, common_type, true);
}
//TODO refactor so that no tensor copies are done
std::vector<Tensor> tensors;
std::transform(std::begin(operands), std::end(operands), std::back_inserter(tensors),
[](const OperandInfo& op) { return op.tensor; });
auto dtype = at::native::result_type(tensors);
auto result = std::make_tuple(device, dtype);
auto result = std::make_tuple(device, dtype, false);
TORCH_INTERNAL_ASSERT(dtype != ScalarType::Undefined);
return result;
}

std::tuple<Device, ScalarType> TensorIterator::compute_common_type() {
std::tuple<Device, ScalarType, bool> TensorIterator::compute_common_type() {
return compute_common_type_(operands_);
}

Expand Down Expand Up @@ -199,11 +225,13 @@ void TensorIterator::compute_types() {
}
}

if (!compute_common_dtype_only_for_inputs) {
validate_dtype(op, common_dtype, ninputs());
}
if (!compute_common_dtype_only_for_inputs || !op.is_output) {
maybe_promote_common_dtype(op, common_dtype);
if (!std::get<2>(common_type)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indentation

if (!compute_common_dtype_only_for_inputs) {
validate_dtype(op, common_dtype, ninputs());
}
if (!compute_common_dtype_only_for_inputs || !op.is_output) {
maybe_promote_common_dtype(op, common_dtype);
}
}

if (op.tensor.defined() && op.device != op.tensor.device()) {
Expand All @@ -221,8 +249,8 @@ void TensorIterator::compute_types() {
}
}

DimVector TensorIterator::compatible_stride(int element_size) const {
auto stride = DimVector();
StrideVector TensorIterator::compatible_stride(int element_size) const {
auto stride = StrideVector();
int64_t next_stride = element_size;
for (int dim = 0; dim < ndim(); dim++) {
stride.push_back(next_stride);
Expand Down Expand Up @@ -369,9 +397,9 @@ int64_t TensorIterator::numel() const {
return numel;
}

DimVector TensorIterator::get_dim_strides(int dim) const {
StrideVector TensorIterator::get_dim_strides(int dim) const {
auto dims = ndim();
auto inner_strides = DimVector();
auto inner_strides = StrideVector();
for (auto& op : operands_) {
inner_strides.push_back(dims == 0 ? 0 : op.stride_bytes[dim]);
}
Expand Down Expand Up @@ -478,8 +506,8 @@ void TensorIterator::for_each(loop2d_t loop) {
}
}

DimVector TensorIterator::get_strides() const {
DimVector strides;
StrideVector TensorIterator::get_strides() const {
StrideVector strides;
for (int dim = 0; dim < ndim(); dim++) {
for (int arg = 0; arg < ntensors(); arg++) {
strides.push_back(operands_[arg].stride_bytes[dim]);
Expand Down Expand Up @@ -751,9 +779,11 @@ void TensorIterator::compute_strides() {
auto original_shape = op.tensor.sizes();
auto original_stride = op.tensor.strides();
auto element_size_in_bytes = op.tensor.element_size();

op.stride_bytes.resize(ndim(), 0);
auto offset = ndim() - original_shape.size();
if (offset > 0)
op.stride_bytes.resize(ndim(), 0);
else
op.stride_bytes.resize(ndim());
for (size_t i = 0; i < original_shape.size(); i++) {
if (original_shape[i] == 1) {
op.stride_bytes[offset + i] = 0;
Expand Down
16 changes: 9 additions & 7 deletions aten/src/ATen/native/TensorIterator.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
// (See https://github.com/pytorch/pytorch/issues/9515)
//
// Note that TensorIterator currently supports type conversions on 0-dim
// tensors and arithmetic operators. Other type conversions will raise an
// tensors and arithmetic operators. Other type conversions will raise an
// exception.

namespace at {
Expand All @@ -71,6 +71,7 @@ struct DimCounter {
};

struct CAFFE2_API OperandInfo {
using StrideVector = SmallVector<int64_t, 6>;
OperandInfo() {}
explicit OperandInfo(const Tensor& t) : tensor(t) {
if (t.defined()) {
Expand All @@ -85,7 +86,7 @@ struct CAFFE2_API OperandInfo {
}

/// Stride after broadcasting. The stride is in bytes, not number of elements.
DimVector stride_bytes;
StrideVector stride_bytes;

/// The tensor operand. Note that the strides, data pointer, and
/// other attributes may differ due to dimension reordering and
Expand Down Expand Up @@ -134,6 +135,7 @@ enum class CommonDTypeStrategy : uint8_t {
struct CAFFE2_API TensorIterator {
using DimMask = std::bitset<64>;
using PtrVector = SmallVector<char*, 4>;
using StrideVector = SmallVector<int64_t, 6>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

super nit: maybe move these magic numbers into a constant with a good name
super super nit: make PtrVector the same size as StrideVector


TensorIterator() {}

Expand Down Expand Up @@ -254,16 +256,16 @@ struct CAFFE2_API TensorIterator {
/// Create a strides array for a Tensor with shape of this iterator. The
/// parameter `element_size` specifies the size of Tensor's data type in
/// bytes (e.g. `4` for `float`)
DimVector compatible_stride(int element_size) const;
StrideVector compatible_stride(int element_size) const;

/// Inverts the re-ordering done by reorder_dimensions. This can only be
/// called *before* coalesce_dimensions() is called.
DimVector invert_perm(IntArrayRef input) const;

/// Helper functions for CPU iteration
DimVector get_dim_strides(int dim) const;
DimVector get_strides() const;
DimVector get_inner_strides() const { return get_dim_strides(0); }
StrideVector get_dim_strides(int dim) const;
StrideVector get_strides() const;
StrideVector get_inner_strides() const { return get_dim_strides(0); }
PtrVector get_data_ptrs(ArrayRef<char*> base, IntArrayRef counter) const;
PtrVector get_base_ptrs() const;

Expand Down Expand Up @@ -328,7 +330,7 @@ struct CAFFE2_API TensorIterator {
void reorder_dimensions();
void permute_dimensions(IntArrayRef perm);
void compute_types();
std::tuple<Device, ScalarType> compute_common_type();
std::tuple<Device, ScalarType, bool> compute_common_type();
void allocate_outputs();
#ifdef BUILD_NAMEDTENSOR
void compute_names();
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/TypeProperties.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ ScalarType result_type(TensorList tensors) {
auto dimResult = ScalarType::Undefined;
auto zeroResult = ScalarType::Undefined;
auto wrappedResult = ScalarType::Undefined;
for (Tensor tensor : tensors) {
for (const Tensor& tensor : tensors) {
if (!tensor.defined()) {
continue;
}
Expand Down