36 changes: 36 additions & 0 deletions test/cpp/api/cursor.cpp
@@ -323,3 +323,39 @@ TEST_CASE("cursor/parameter") {
}
}
}

TEST_CASE("cursor/non-const-to-const-conversion") {
auto first = std::make_shared<TestModule>(1);
auto second = std::make_shared<TestModule>(2);
Container model(first, second);

{
ConstModuleCursor const_cursor(model.modules());
{
ModuleCursor cursor = model.modules();
ConstModuleCursor const_cursor = cursor;
}
}
{
ConstParameterCursor const_cursor(model.parameters());
{
ParameterCursor cursor = model.parameters();
ConstParameterCursor const_cursor = cursor;
}
}
{
ConstBufferCursor const_cursor(model.buffers());
{
BufferCursor cursor = model.buffers();
ConstBufferCursor const_cursor = cursor;
}
}
}

TEST_CASE("cursor/can-invoke-const-method-on-const-cursor") {
TestModule model(1);

/// This will only compile if `Cursor` has the appropriate const methods.
const auto cursor = model.parameters();
REQUIRE(cursor.contains("tensor1"));
}
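
The conversions above only compile if each non-const cursor type implicitly converts to its const counterpart. One way to get that behavior is a constness-gated converting constructor; the following is a minimal, self-contained sketch of the pattern, not the actual libtorch implementation:

```cpp
#include <type_traits>
#include <utility>
#include <vector>

template <typename T>
class Cursor {
 public:
  explicit Cursor(std::vector<T*> items) : items_(std::move(items)) {}

  // Converting constructor, enabled only when T is const-qualified:
  // Cursor<Module> converts to Cursor<const Module>, never the reverse.
  template <
      typename U,
      typename = std::enable_if_t<
          std::is_const<T>::value &&
          std::is_same<U, std::remove_const_t<T>>::value>>
  /* implicit */ Cursor(const Cursor<U>& other)
      : items_(other.items_.begin(), other.items_.end()) {}

 private:
  template <typename>
  friend class Cursor;
  std::vector<T*> items_;
};
```

With such a constructor, `ConstModuleCursor const_cursor = cursor;` in the test is a plain implicit conversion, while the reverse direction is rejected at compile time.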
2 changes: 1 addition & 1 deletion test/cpp/api/integration.cpp
@@ -6,7 +6,7 @@
#include <torch/nn/modules/linear.h>
#include <torch/optimizers.h>
#include <torch/tensor.h>
#include <torch/tensor_range.h>
#include <torch/tensor_list_view.h>
#include <torch/utils.h>

#include <test/cpp/api/util.h>
95 changes: 73 additions & 22 deletions test/cpp/api/module.cpp
@@ -67,42 +67,51 @@ TEST_CASE("module/name") {
}

TEST_CASE("module/conversions", "[cuda]") {
auto module = LSTM(LSTMOptions(128, 64).layers(3).dropout(0.2));
Linear module(128, 64);
SECTION("starts as float on CPU") {
for (auto& parameter : module->parameters()) {
REQUIRE(parameter->type().backend() == at::kCPU);
REQUIRE(parameter->type().scalarType() == torch::kFloat32);
REQUIRE(parameter->device() == at::Device(at::kCPU));
REQUIRE(parameter->dtype() == torch::kFloat32);
}
}
SECTION("to(CUDA)") {
module->cuda();
module->to({at::kCUDA, 0});
for (auto& parameter : module->parameters()) {
REQUIRE(parameter->type().backend() == at::kCUDA);
REQUIRE(parameter->device().type() == at::Device::Type::CUDA);
REQUIRE(parameter->device().index() == 0);
}
module->cuda(1);
for (auto& parameter : module->parameters()) {
REQUIRE(parameter->device().type() == at::Device::Type::CUDA);
REQUIRE(parameter->device().index() == 1);
}
}
SECTION("to(CPU)") {
module->to(at::kCPU);
module->to(at::Device(at::kCPU));
for (auto& parameter : module->parameters()) {
REQUIRE(parameter->type().backend() == at::kCPU);
REQUIRE(parameter->device().type() == at::Device::Type::CPU);
}
}
SECTION("to(Int)") {
SECTION("to(Int32)") {
module->to(torch::kInt32);
for (auto& parameter : module->parameters()) {
REQUIRE(parameter->type().scalarType() == torch::kInt32);
REQUIRE(parameter->dtype() == torch::kInt32);
}
}
SECTION("to(Double)") {
SECTION("to(Float64)") {
module->to(torch::kFloat64);
for (auto& parameter : module->parameters()) {
REQUIRE(parameter->type().scalarType() == torch::kFloat64);
REQUIRE(parameter->dtype() == torch::kFloat64);
}
}
SECTION("to(CUDA(Float))") {
module->to(at::CUDA(torch::kFloat32));
SECTION("to(CUDA, Byte)") {
module->to(at::Device(at::kCUDA, 1), torch::kUInt8);
for (auto& parameter : module->parameters()) {
REQUIRE(parameter->device().type() == at::Device::Type::CUDA);
REQUIRE(parameter->device().index() == 1);
}
for (auto& parameter : module->parameters()) {
REQUIRE(parameter->type().backend() == at::kCUDA);
REQUIRE(parameter->type().scalarType() == torch::kFloat32);
REQUIRE(parameter->dtype() == torch::kUInt8);
}
}
}
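
The updated sections replace the old `type().backend()` / `type().scalarType()` queries with `device()` and `dtype()`, and exercise an overload set that accepts a device, a dtype, or both. Collected in one place as a usage sketch (overloads inferred from the test above, namespace qualification assumed):

```cpp
#include <cassert>
#include <torch/nn/modules/linear.h>

void convert(torch::nn::Linear& module) {
  module->to(at::Device(at::kCPU));                     // device only
  module->to(torch::kFloat64);                          // dtype only
  module->to(at::Device(at::kCUDA, 1), torch::kUInt8);  // device and dtype
  module->cuda(1);                                      // shorthand for a CUDA index
  for (auto& parameter : module->parameters()) {
    // Post-conditions are now queried via device() and dtype().
    assert(parameter->device() == at::Device(at::kCUDA, 1));
    assert(parameter->dtype() == torch::kUInt8);
  }
}
```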
@@ -133,23 +142,40 @@ TEST_CASE("module/clone") {
l1 = register_module("l1", Linear(10, 3));
l2 = register_module("l2", Linear(3, 5));
l3 = register_module("l3", Linear(5, 100));
buffer = register_buffer("buf", torch::ones({2, 2}));
}

Linear l1, l2, l3;
torch::Tensor buffer;
};

auto module = TestModule().build();

auto module2 = module->clone();
auto m1param = module->parameters();
auto m2param = module2->parameters();
for (auto& param : m1param) {
REQUIRE(!pointer_equal(param.value, m2param[param.key]));
REQUIRE(param->allclose(m2param[param.key]));
auto params1 = module->parameters();
auto params2 = module2->parameters();
REQUIRE(params1.size() == 6);
REQUIRE(params2.size() == 6);
for (auto& param : params1) {
REQUIRE(!pointer_equal(param.value, params2[param.key]));
REQUIRE(param->allclose(params2[param.key]));
param->data().mul_(2);
}
for (auto& param : m1param) {
REQUIRE(!param->allclose(m2param[param.key]));
for (auto& param : params1) {
REQUIRE(!param->allclose(params2[param.key]));
}

auto buffers1 = module->buffers();
auto buffers2 = module2->buffers();
REQUIRE(buffers1.size() == 1);
REQUIRE(buffers2.size() == 1);
for (auto& buffer : buffers1) {
REQUIRE(!pointer_equal(buffer.value, buffers2[buffer.key]));
REQUIRE(buffer->allclose(buffers2[buffer.key]));
buffer->data().mul_(2);
}
for (auto& buffer : buffers1) {
REQUIRE(!buffer->allclose(buffers2[buffer.key]));
}
}
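
`pointer_equal` comes from `test/cpp/api/util.h`; a plausible definition (an assumption, shown only to make the assertions readable) compares the tensors' underlying data pointers, so the loops assert that `clone()` deep-copies every parameter and buffer rather than aliasing them:

```cpp
// Hypothetical helper mirroring test/cpp/api/util.h, not the verbatim source:
inline bool pointer_equal(torch::Tensor first, torch::Tensor second) {
  return first.data_ptr() == second.data_ptr();
}
```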

@@ -229,3 +255,28 @@ TEST_CASE("module/parameters") {
REQUIRE(parameters.contains("c"));
}
}

TEST_CASE("module/buffers") {
struct TestModule : Module {
TestModule() {
a = register_buffer("a", torch::zeros({2, 2}));
b = register_buffer("b", torch::ones({2, 2}));
c = register_buffer("c", torch::ones({2, 2}) * 2);
}

torch::Tensor a, b, c;
};

TestModule module;

SECTION("has correct number of buffers") {
REQUIRE(module.buffers().size() == 3);
}

SECTION("contains buffers with the correct name") {
auto buffers = module.buffers();
REQUIRE(buffers.contains("a"));
REQUIRE(buffers.contains("b"));
REQUIRE(buffers.contains("c"));
}
}
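
Buffers follow the module for state management — they move with `to()` and are deep-copied by `clone()`, as the updated clone test above checks — but they are excluded from `parameters()`, so optimizers never touch them; BatchNorm's running statistics are the canonical example. A sketch contrasting the two, assuming `register_parameter` mirrors the `register_buffer` call above:

```cpp
struct Stats : torch::nn::Module {  // Module assumed to live in torch::nn
  Stats() {
    weight = register_parameter("weight", torch::randn({2, 2}));
    running_mean = register_buffer("running_mean", torch::zeros({2}));
  }
  torch::Tensor weight, running_mean;
};

// parameters() sees only `weight`; buffers() sees only `running_mean`.
```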
18 changes: 18 additions & 0 deletions test/cpp/api/tensor.cpp
@@ -129,6 +129,24 @@ TEST_CASE("Tensor/ContainsCorrectValuesForManyValues") {
REQUIRE(almost_equal(tensor[2], 3.125));
}

TEST_CASE("Tensor/ContainsCorrectValuesForManyValuesVariable") {
auto tensor = torch::tensor({1, 2, 3});
REQUIRE(tensor.is_variable());
REQUIRE(tensor.numel() == 3);
REQUIRE(tensor.dtype() == at::kInt);
REQUIRE(exactly_equal(tensor[0], 1));
REQUIRE(exactly_equal(tensor[1], 2));
REQUIRE(exactly_equal(tensor[2], 3));

tensor = torch::tensor({1.5, 2.25, 3.125});
REQUIRE(tensor.is_variable());
REQUIRE(tensor.numel() == 3);
REQUIRE(tensor.dtype() == at::kDouble);
REQUIRE(almost_equal(tensor[0], 1.5));
REQUIRE(almost_equal(tensor[1], 2.25));
REQUIRE(almost_equal(tensor[2], 3.125));
}
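
The new factory deduces the element type from the values and returns an `autograd::Variable` rather than a plain `at::Tensor`; passing an explicit `at::TensorOptions` overrides the deduction. Usage sketch, using the same `at::dtype` helper as the factory macro below:

```cpp
auto a = torch::tensor({1, 2, 3});                        // deduced at::kInt
auto b = torch::tensor({1.5, 2.25, 3.125});               // deduced at::kDouble
auto c = torch::tensor({1, 2, 3}, at::dtype(at::kLong));  // explicit int64
auto d = torch::tensor(7);  // single value: one-element tensor via ArrayRef
```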

TEST_CASE("Tensor/ContainsCorrectValuesWhenConstructedFromVector") {
std::vector<int> v = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
auto tensor = at::tensor(v);
2 changes: 1 addition & 1 deletion tools/autograd/templates/python_torch_functions_dispatch.h
@@ -25,7 +25,7 @@ using at::Storage;
using at::TensorOptions;

static at::Type& default_type() {
return torch::tensor::get_default_tensor_type();
return torch::tensors::get_default_tensor_type();
}

static void maybe_initialize_cuda(const at::Type &type) {
31 changes: 31 additions & 0 deletions tools/autograd/templates/variable_factories.h
@@ -5,7 +5,38 @@
#include <torch/csrc/autograd/variable.h>

#include <ATen/ATen.h>
#include <ATen/ArrayRef.h>

#include <initializer_list>
#include <utility>

namespace torch {

#define TENSOR(T, S, _1) \
inline autograd::Variable tensor( \
at::ArrayRef<T> values, const at::TensorOptions& options) { \
at::Tensor result = at::tensor(values, options.discard_runtime_type()); \
return autograd::make_variable(result, options.requires_grad()); \
} \
inline autograd::Variable tensor( \
std::initializer_list<T> values, const at::TensorOptions& options) { \
return torch::tensor(at::ArrayRef<T>(values), options); \
} \
inline autograd::Variable tensor( \
T value, const at::TensorOptions& options) { \
return torch::tensor(at::ArrayRef<T>(value), options); \
} \
inline autograd::Variable tensor(at::ArrayRef<T> values) { \
return torch::tensor(std::move(values), at::dtype(at::k##S)); \
} \
inline autograd::Variable tensor(std::initializer_list<T> values) { \
return torch::tensor(at::ArrayRef<T>(values)); \
} \
inline autograd::Variable tensor(T value) { \
return torch::tensor(at::ArrayRef<T>(value)); \
}
AT_FORALL_SCALAR_TYPES_EXCEPT_HALF(TENSOR)
#undef TENSOR

${function_definitions}
} // namespace torch
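
For reference, one instantiation of the `TENSOR` macro above, written out for `(int32_t, Int)`; the other scalar types expand the same way:

```cpp
inline autograd::Variable tensor(
    at::ArrayRef<int32_t> values, const at::TensorOptions& options) {
  // Build a plain at::Tensor, then wrap it as a Variable so the result can
  // participate in autograd when requires_grad() is set.
  at::Tensor result = at::tensor(values, options.discard_runtime_type());
  return autograd::make_variable(result, options.requires_grad());
}
inline autograd::Variable tensor(at::ArrayRef<int32_t> values) {
  // No options given: fall back to the dtype matching the element type.
  return torch::tensor(std::move(values), at::dtype(at::kInt));
}
// ...plus the std::initializer_list<int32_t> and single-int32_t overloads,
// which forward to the two ArrayRef overloads above.
```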
2 changes: 2 additions & 0 deletions torch/CMakeLists.txt
@@ -155,6 +155,7 @@ add_custom_command(
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_nn_functions.cpp"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_nn_functions.h"
"${TORCH_SRC_DIR}/csrc/autograd/generated/python_nn_functions_dispatch.h"
"${TORCH_SRC_DIR}/csrc/autograd/generated/variable_factories.h"
"${TORCH_SRC_DIR}/csrc/jit/generated/aten_dispatch.cpp"
"${TORCH_SRC_DIR}/csrc/jit/generated/aten_schema.cpp"
"${TORCH_SRC_DIR}/csrc/jit/generated/aten_interned_strings.h"
@@ -178,6 +179,7 @@ add_custom_command(
"${TOOLS_PATH}/autograd/templates/python_nn_functions.cpp"
"${TOOLS_PATH}/autograd/templates/python_nn_functions.h"
"${TOOLS_PATH}/autograd/templates/python_nn_functions_dispatch.h"
"${TOOLS_PATH}/autograd/templates/variable_factories.h"
"${TOOLS_PATH}/autograd/gen_autograd.py"
"${TOOLS_PATH}/autograd/gen_autograd_functions.py"
"${TOOLS_PATH}/autograd/gen_variable_type.py"
10 changes: 5 additions & 5 deletions torch/csrc/Module.cpp
@@ -93,7 +93,7 @@ static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manager_path)
}
torch::utils::initializeLayouts();
torch::utils::initializeDtypes();
torch::tensor::initialize_python_bindings();
torch::tensors::initialize_python_bindings();
std::string path = THPUtils_unpackString(shm_manager_path);
libshm_init(path.c_str());

@@ -149,15 +149,15 @@ static PyObject * THPModule_setNumThreads(PyObject *module, PyObject *arg)
PyObject * THPModule_setDefaultTensorType(PyObject *_unused, PyObject *type)
{
HANDLE_TH_ERRORS
torch::tensor::py_set_default_tensor_type(type);
torch::tensors::py_set_default_tensor_type(type);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}

PyObject * THPModule_setDefaultDtype(PyObject *_unused, PyObject *dtype)
{
HANDLE_TH_ERRORS
torch::tensor::py_set_default_dtype(dtype);
torch::tensors::py_set_default_dtype(dtype);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
@@ -364,7 +364,7 @@ PyObject *THPModule_setFlushDenormal(PyObject *_unused, PyObject *arg) {

PyObject *THPModule_getDefaultDtype(PyObject *_unused, PyObject *arg) {
HANDLE_TH_ERRORS
auto& type = torch::tensor::get_default_tensor_type();
auto& type = torch::tensors::get_default_tensor_type();
auto dtype = (PyObject*)torch::getDtype(type.scalarType());
Py_INCREF(dtype);
return dtype;
@@ -373,7 +373,7 @@ PyObject *THPModule_getDefaultDtype(PyObject *_unused, PyObject *arg) {

PyObject *THPModule_isDefaultTypeCuda(PyObject *_unused, PyObject *arg) {
HANDLE_TH_ERRORS
if (torch::tensor::get_default_tensor_type().is_cuda()) {
if (torch::tensors::get_default_tensor_type().is_cuda()) {
Py_RETURN_TRUE;
}
Py_RETURN_FALSE;
4 changes: 4 additions & 0 deletions torch/csrc/api/include/torch/nn/cloneable.h
@@ -40,11 +40,15 @@ class Cloneable : public Module {
const auto& self = static_cast<const Derived&>(*this);
auto copy = std::make_shared<Derived>(self);
copy->parameters_.clear();
copy->buffers_.clear();
copy->children_.clear();
copy->reset();
for (const auto& parameter : parameters_) {
copy->parameters_[parameter.key].data().copy_(parameter->data());
}
for (const auto& buffer : buffers_) {
copy->buffers_[buffer.key].data().copy_(buffer->data());
}
for (const auto& child : children_) {
copy->children_[child.key]->clone_(*child.value);
}
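
`clone()` relies on `reset()` re-creating all registered state on the copy, after which the values are copied over; the added lines extend that copy to buffers. A module written against this protocol might look like the following sketch (assuming the CRTP `Cloneable<Derived>` base shown above and a virtual `reset()`):

```cpp
struct MyModule : torch::nn::Cloneable<MyModule> {
  MyModule() { reset(); }
  void reset() override {
    // Registers fresh tensors; clone() re-runs this on the copy before
    // overwriting the freshly initialized values with the originals.
    weight = register_parameter("weight", torch::randn({4, 4}));
    running = register_buffer("running", torch::zeros({4}));
  }
  torch::Tensor weight, running;
};
```

Without the new `buffers_` handling, a clone would keep the zero-initialized `running` produced by `reset()` instead of copying the original values.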