Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions aten/src/ATen/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,12 @@ Device::Device(const std::string& device_string) : Device(Type::CPU) {
}
}

} // namespace at

std::ostream& operator<<(std::ostream& stream, const at::Device& device) {
stream << device.type();
if (device.has_index()) {
stream << ":" << device.index();
}
return stream;
}

} // namespace at
4 changes: 3 additions & 1 deletion aten/src/ATen/Device.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,12 @@ struct Device {
DeviceType type_;
int32_t index_ = -1;
};
} // namespace at

AT_API std::ostream& operator<<(std::ostream& stream, const at::Device& device);

} // namespace at


namespace std {
template<> struct hash<at::Device>
{
Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/Layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ inline Layout layout_from_backend(Backend backend) {
return Layout::Strided;
}
}
} // namespace at

inline std::ostream& operator<<(std::ostream& stream, at::Layout layout) {
switch (layout) {
Expand All @@ -32,3 +31,5 @@ inline std::ostream& operator<<(std::ostream& stream, at::Layout layout) {
AT_ERROR("Unknown layout");
}
}

} // namespace at
3 changes: 3 additions & 0 deletions aten/src/ATen/TensorGeometry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
namespace at {

bool TensorGeometry::is_contiguous() const {
if (numel_ == 0) {
return true;
}
int64_t dim = sizes_.size();
int64_t expected_stride = 1;
for (int64_t i = dim - 1; i >= 0; i--) {
Expand Down
13 changes: 5 additions & 8 deletions aten/src/ATen/TensorGeometry.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ struct AT_API TensorGeometry {
strides_[i] = expected_stride;
expected_stride *= sizes_[i];
}
numel_ = expected_stride;
}

explicit TensorGeometry(const Tensor& t)
: sizes_(t.sizes().vec())
, strides_(t.strides().vec())
, storage_offset_(t.storage_offset()) {}
, storage_offset_(t.storage_offset())
, numel_(t.numel()) {}

// true if the tensor is contiguous
bool is_contiguous() const;
Expand All @@ -43,13 +45,7 @@ struct AT_API TensorGeometry {
}
IntList strides() const { return IntList{ strides_ }; }
int64_t storage_offset() const { return storage_offset_; }
int64_t numel() const {
int64_t r = 1;
for (auto s : sizes()) {
r *= s;
}
return r;
}
int64_t numel() const { return numel_; }

TensorGeometry transpose(int64_t dim0, int64_t dim1) {
TensorGeometry r = *this; // copy
Expand All @@ -63,6 +59,7 @@ struct AT_API TensorGeometry {
std::vector<int64_t> sizes_;
std::vector<int64_t> strides_;
int64_t storage_offset_;
int64_t numel_;
};

} // namespace at
2 changes: 1 addition & 1 deletion aten/src/ATen/TensorUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ void checkSameGPU(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) {
oss << "Tensor for " << t2 << " is on CPU, ";
}
oss << "but expected " << ((!(t1->is_cuda() || t2->is_cuda())) ? "them" : "it")
<< " to be on GPU (while checking arguments for " << c << ")";
<< " to be on GPU (while checking arguments for " << c << ")";
AT_ERROR(oss.str());
}
AT_CHECK(
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/core/DeviceType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ std::string DeviceTypeName(at::DeviceType d, bool lower_case) {
}
}

} // namespace at

std::ostream& operator<<(std::ostream& stream, at::DeviceType type) {
stream << at::DeviceTypeName(type, /* lower case */ true);
return stream;
}

} // namespace at
4 changes: 2 additions & 2 deletions aten/src/ATen/core/DeviceType.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ AT_CORE_API std::string DeviceTypeName(
at::DeviceType d,
bool lower_case = false);

} // namespace at

AT_CORE_API std::ostream& operator<<(std::ostream& stream, at::DeviceType type);

} // namespace at
4 changes: 4 additions & 0 deletions aten/src/ATen/cuda/detail/KernelUtils.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#pragma once

#include "ATen/ATen.h"

// Contents of this file are copied from THCUNN/common.h for the ease of porting
// THCUNN functions into ATen.

Expand All @@ -14,6 +17,7 @@ constexpr int CUDA_NUM_THREADS = 1024;
// CUDA: number of blocks for threads.
inline int GET_BLOCKS(const int N)
{
AT_ASSERTM(N > 0, "CUDA kernel launch blocks must be positive, but got N=", N);
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}

Expand Down
Loading