Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 61 additions & 10 deletions aten/src/ATen/TensorOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <ATen/Context.h>
#include <ATen/Device.h>
#include <ATen/DeviceGuard.h>
#include <ATen/Layout.h>
#include <ATen/ScalarType.h>
#include <ATen/Tensor.h>
Expand All @@ -28,8 +29,9 @@ struct TensorOptions {

/// Constructs the `TensorOptions` from the type of the given `Tensor`.
/// If the `Tensor` has a CUDA type, the `device_index` will match that of the
/// tensor. See the constructor from `Type` for the semantics w.r.t. the
/// `type()` method.
/// tensor. The `requires_grad` property of the tensor is ignored and set to
/// false in the created `TensorOptions`. See the constructor from `Type` for
/// the semantics w.r.t. the `type()` method.
explicit TensorOptions(Tensor tensor, bool discard_runtime_type = false) {
if (!discard_runtime_type) {
type_ = &tensor.type();
Expand Down Expand Up @@ -84,6 +86,18 @@ struct TensorOptions {
this->dtype(dtype);
}

/// True if every construction axis (dtype, layout, device, requires_grad)
/// of this `TensorOptions` matches that of `other`. Note that the cached
/// runtime `type_` pointer is not part of the comparison.
bool operator==(const TensorOptions& other) const noexcept {
  if (!(dtype_ == other.dtype_)) {
    return false;
  }
  if (!(layout_ == other.layout_)) {
    return false;
  }
  if (!(device_ == other.device_)) {
    return false;
  }
  return requires_grad_ == other.requires_grad_;
}

/// True if any of the elements of this `TensorOptions` do not match that of
/// the other (defined as the negation of `operator==`).
bool operator!=(const TensorOptions& other) const noexcept {
  const bool equal = (*this == other);
  return !equal;
}

/// Discards the runtime type stored if the `TensorOptions` was constructed
/// from a `Tensor` or a `Type`. See the documentation of the constructor from
/// a `Type` for implications on the behavior of the `type()` method on
Expand All @@ -93,13 +107,10 @@ struct TensorOptions {
return *this;
}

// NOTE: These methods are defined in TensorOptions.cpp because I get funny
// linker errors for their missing definition if they're defined in the
// header. Who knows why?

/// Sets the device of the `TensorOptions`.
/// Also rebuilds any cached runtime type via `update_underlying_type()`, so
/// a `TensorOptions` constructed from a `Tensor`/`Type` stays consistent.
TensorOptions& device(Device device) {
device_ = std::move(device);
update_underlying_type();
return *this;
}

Expand All @@ -112,17 +123,19 @@ struct TensorOptions {
/// Sets the dtype of the `TensorOptions`.
/// Also rebuilds any cached runtime type via `update_underlying_type()`.
TensorOptions& dtype(ScalarType dtype) {
dtype_ = dtype;
update_underlying_type();
return *this;
}

/// Sets the layout of the `TensorOptions`.
/// Also rebuilds any cached runtime type via `update_underlying_type()`
/// (layout participates in `backend()` resolution: strided vs. sparse).
TensorOptions& layout(Layout layout) {
layout_ = layout;
update_underlying_type();
return *this;
}

/// Sets the `requires_grad` property of the `TensorOptions`.
// NOTE(review): this hunk drops the `= true` default on the parameter, so
// member-call sites must now pass the flag explicitly — see the
// `requires_grad(true)` update in test/cpp/api/modules.cpp in this diff.
TensorOptions& requires_grad(bool requires_grad = true) {
TensorOptions& requires_grad(bool requires_grad) {
requires_grad_ = requires_grad;
return *this;
}
Expand Down Expand Up @@ -157,16 +170,28 @@ struct TensorOptions {
if (type_ != nullptr) {
return *type_;
}
return getType(backend(), dtype_);
}

private:
/// Updates any stored underlying type to the current construction axes.
void update_underlying_type() {
if (type_) {
type_ = &type_->toScalarType(dtype_).toBackend(backend());
}
}

// Resolves the ATen backend implied by the current construction axes:
// device type selects CPU vs. CUDA, layout selects dense vs. sparse.
Backend backend() const noexcept {
  const bool dense = (layout_ == kStrided);
  if (device_.type() == Device::Type::CPU) {
    return dense ? kCPU : kSparseCPU;
  }
  return dense ? kCUDA : kSparseCUDA;
}

protected:
ScalarType dtype_{kFloat};
Device device_{Device::Type::CPU};
Layout layout_{Layout::Strided};
Expand Down Expand Up @@ -209,5 +234,31 @@ inline TensorOptions requires_grad(bool requires_grad = true) {
/// From Tensor.h
/// Returns a `TensorOptions` describing this tensor. Per the
/// `TensorOptions(Tensor)` constructor, the tensor's `requires_grad` is NOT
/// carried over — the result always has `requires_grad == false`.
inline TensorOptions Tensor::options() const {
return TensorOptions(*this);
}
}

namespace detail {
/// Converts `tensor` to the configuration described by `options`, copying
/// through the target `Type` under a `DeviceGuard` for the target device.
/// When every option already matches, the input tensor is returned as-is
/// (no copy).
inline Tensor to(
    const Tensor& tensor,
    const TensorOptions& options,
    bool non_blocking) {
  const bool already_matches = (tensor.options() == options);
  if (already_matches) {
    // Don't copy if the options match.
    return tensor;
  }
  DeviceGuard guard(options.device());
  return options.type().copy(tensor, non_blocking);
}
} // namespace detail

/// Converts this tensor to the given device and dtype in one step, keeping
/// all other options. Returns the tensor unchanged (no copy) when the
/// resulting options already match (see `detail::to`).
inline Tensor Tensor::to(Device device, ScalarType dtype, bool non_blocking) {
return detail::to(*this, options().device(device).dtype(dtype), non_blocking);
}

/// Converts this tensor to the given dtype, keeping the current device.
inline Tensor Tensor::to(ScalarType dtype, bool non_blocking) {
return detail::to(*this, options().dtype(dtype), non_blocking);
}

/// Moves this tensor to the given device, keeping the current dtype.
inline Tensor Tensor::to(Device device, bool non_blocking) {
return detail::to(*this, options().device(device), non_blocking);
}
} // namespace at
5 changes: 5 additions & 0 deletions aten/src/ATen/templates/Tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ struct Tensor : public detail::TensorBase {
inline Tensor toType(ScalarType t) const;
inline Tensor toBackend(Backend b) const;

/// New-style `to()` methods: convert this tensor's device and/or dtype.
/// Defined inline in TensorOptions.h; they delegate to `detail::to`, which
/// returns the tensor unchanged when the requested options already match.
Tensor to(Device device, ScalarType dtype, bool non_blocking = false);
Tensor to(ScalarType dtype, bool non_blocking = false);
Tensor to(Device device, bool non_blocking = false);

/// Returns true if the `Tensor` is actually a `torch::autograd::Variable`.
/// Defined in Type.h because of include order issues.
bool is_variable() const noexcept;
Expand Down
2 changes: 1 addition & 1 deletion test/cpp/api/modules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ TEST_CASE("modules_cuda", "[cuda]") {
SECTION("1") {
Linear model(5, 2);
model->cuda();
auto x = torch::randn({10, 5}, at::device(at::kCUDA).requires_grad());
auto x = torch::randn({10, 5}, at::device(at::kCUDA).requires_grad(true));

This comment was marked as off-topic.

This comment was marked as off-topic.

auto y = model->forward({x})[0];
Variable s = y.sum();

Expand Down
96 changes: 79 additions & 17 deletions test/cpp/api/tensor.cpp
Original file line number Diff line number Diff line change
@@ -1,23 +1,85 @@
#include <catch.hpp>

#include <torch/functions.h>

#include <ATen/ATen.h>

TEST_CASE("tensor/device-placement") {
SECTION("DeviceGuard") {
// SECTION("On index zero by default") {
// auto tensor = at::ones({3, 3}, at::kCUDA);
// REQUIRE(tensor.get_device() == 0);
// }

// // right hand side is TensorOptions
// torch::OptionGuard guard = torch::device(torch::kCUDA, 1);
// // convenience wrapper over OptionGuard
// torch::DeviceGuard guard(torch::kCUDA, 1);
// /// default device is CUDA
// torch::DeviceGuard guard(1);

// note that this is separate from DeviceGuard. DeviceGuard should move into the
// detail namespace and do the actual thing. OptionGuard just modifies a
// global singleton of option defaults. It operates at a higher level.
// Asserts the device type/index, dtype, and layout of a tensor.
// NOTE: the macro body reads a local variable named `tensor`, which must be
// in scope at every expansion site.
#define REQUIRE_TENSOR_OPTIONS(device_, index_, type_, layout_) \
REQUIRE(tensor.device().type() == at::Device((device_), (index_)).type()); \
REQUIRE(tensor.device().index() == at::Device((device_), (index_)).index()); \
REQUIRE(tensor.dtype() == (type_)); \
REQUIRE(tensor.layout() == (layout_))

// `to(ScalarType)` must change only the dtype; device and layout stay put
// across a chain of conversions.
TEST_CASE("Tensor/ToDtype") {
auto tensor = at::empty({3, 4});
REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);

tensor = tensor.to(at::kInt);
REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kInt, at::kStrided);

tensor = tensor.to(at::kChar);
REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kChar, at::kStrided);

tensor = tensor.to(at::kDouble);
REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kDouble, at::kStrided);
}

// Not currently supported.
// TEST_CASE("Tensor/ToLayout") {
// auto tensor = at::empty({3, 4});
// REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
//
// tensor = tensor.to(at::kSparse);
// REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kSparse);
//
// tensor = tensor.to(at::kStrided);
// REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
// }

// `to(Device)` round-trips a tensor CPU -> CUDA:1 -> CUDA:0 -> CUDA:1 -> CPU,
// keeping dtype/layout untouched. Tagged [cuda]; the index-1 steps assume at
// least two CUDA devices are present.
TEST_CASE("Tensor/ToDevice", "[cuda]") {
auto tensor = at::empty({3, 4});
REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);

tensor = tensor.to({at::kCUDA, 1});
REQUIRE_TENSOR_OPTIONS(at::kCUDA, 1, at::kFloat, at::kStrided);

tensor = tensor.to({at::kCUDA, 0});
REQUIRE_TENSOR_OPTIONS(at::kCUDA, 0, at::kFloat, at::kStrided);

tensor = tensor.to({at::kCUDA, 1});
REQUIRE_TENSOR_OPTIONS(at::kCUDA, 1, at::kFloat, at::kStrided);

tensor = tensor.to(at::Device(at::kCPU));
REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);
}

// The two-argument overload changes device and dtype in a single call.
TEST_CASE("Tensor/ToDeviceAndDtype", "[cuda]") {
auto tensor = at::empty({3, 4});
REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kFloat, at::kStrided);

tensor = tensor.to({at::kCUDA, 1}, at::kInt);
REQUIRE_TENSOR_OPTIONS(at::kCUDA, 1, at::kInt, at::kStrided);
}

// `to()` must preserve the source tensor's `requires_grad` state in both
// directions, even though `TensorOptions(Tensor)` itself always drops the
// flag to false.
TEST_CASE("Tensor/ToOptionsRespectsRequiresGrad") {
{
auto tensor = torch::empty({3, 4}, at::requires_grad());
REQUIRE(tensor.requires_grad());

tensor = tensor.to(at::kDouble);
REQUIRE(tensor.requires_grad());
}
{
auto tensor = torch::empty({3, 4});
REQUIRE(!tensor.requires_grad());

tensor = tensor.to(at::kDouble);
REQUIRE(!tensor.requires_grad());
}
}

// When every option already matches, `to()` must return the same storage
// rather than a copy — verified by comparing raw data pointers.
TEST_CASE("Tensor/ToDoesNotCopyWhenOptionsAreAllTheSame") {
auto tensor = at::empty({3, 4}, at::kFloat);
auto hopefully_not_copy = tensor.to(at::kFloat);
REQUIRE(hopefully_not_copy.data<float>() == tensor.data<float>());
}
15 changes: 15 additions & 0 deletions test/cpp/api/tensor_options.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
#include "catch.hpp"

#include <torch/functions.h>

#include <ATen/Context.h>
#include <ATen/Functions.h>
#include <ATen/TensorOptions.h>

#include <vector>
#include <string>

using namespace at;

// A macro so we don't lose location information when an assertion fails.
Expand Down Expand Up @@ -65,6 +70,16 @@ TEST_CASE("TensorOptions/ConstructsWellFromCPUTensors") {
REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse);
}

// `TensorOptions` built from a Variable must never inherit `requires_grad`:
// it stays false even when the source variable was created with the flag set.
TEST_CASE("TensorOptions/ConstructsWellFromVariables") {
auto options = TensorOptions(torch::empty(5));
REQUIRE_OPTIONS(kCPU, -1, kFloat, kStrided);
REQUIRE(!options.requires_grad());

options = TensorOptions(torch::empty(5, at::requires_grad()));
REQUIRE_OPTIONS(kCPU, -1, kFloat, kStrided);
REQUIRE(!options.requires_grad());
}

TEST_CASE("Device/ParsesCorrectlyFromString") {
Device device("cpu:0");
REQUIRE(device == Device(kCPU, 0));
Expand Down