10 changes: 10 additions & 0 deletions aten/src/ATen/core/ivalue.cpp
@@ -113,6 +113,16 @@ std::string ivalue::Object::name() const {
return this->type_->qualname();
}

IValue ivalue::Object::getAttr(const std::string& name) const {
Contributor: Probably want the same API as pybind11, which uses attr.

Member Author (@suo, May 8, 2019): pybind's attr returns a functor which lets you write to the object by assigning to it. I think that's probably a little overkill in this case. I think I'll expose static getattr and setattr like the Python builtins.

Contributor: The old API seems simpler; this makes it more like a C API. Maybe we shouldn't try to match Python's semantics here?

Member Author (@suo): I'm okay with that; whatever is more natural for people, I guess.

const size_t slot = type_->getAttributeSlot(name);
return getSlot(slot);
}

void ivalue::Object::setAttr(const std::string& name, IValue v) {
const size_t slot = type_->getAttributeSlot(name);
setSlot(slot, std::move(v));
}

void ivalue::Object::resizeObject(size_t slot) {
AT_ASSERT(slot < type()->numAttributes());
slots_.resize(type()->numAttributes());
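Regarding the attr-vs-getAttr thread above, a minimal sketch of the two styles: the pybind11 half is standard pybind11 usage, while the Object half uses the API added in this diff.

#include <ATen/core/ivalue.h>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// pybind11 style: attr() returns an accessor; assigning through it writes.
void pybind_style(py::object o) {
  o.attr("dx") = py::int_(1);    // write via the accessor functor
  py::object dx = o.attr("dx");  // read
}

// Style chosen here: explicit getter/setter, no intermediate accessor.
void object_style(const c10::intrusive_ptr<c10::ivalue::Object>& obj) {
  obj->setAttr("dx", c10::IValue(1));   // write
  c10::IValue dx = obj->getAttr("dx");  // read
}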
28 changes: 28 additions & 0 deletions aten/src/ATen/core/ivalue.h
@@ -12,6 +12,13 @@

#include <ATen/core/Tensor.h>

namespace torch {
namespace jit {
namespace script {
struct Function;
}
} // namespace jit
} // namespace torch
namespace c10 {
struct IValue;
struct ClassType;
@@ -695,6 +702,14 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
return c10::make_intrusive<Object>(std::move(type), numSlots);
}

/**
* Slot API.
*
* Attributes are stored as a simple vector so that lookups are fast at
* runtime. A "slot" is just an index into that vector, which can be computed
* statically if you have access to the class type. Use this API if you are
* writing compiler stuff.
*/
void setSlot(size_t slot, IValue v) {
if (slot >= slots_.size()) {
// for module types, it is possible that the members of the class have
@@ -709,6 +724,19 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
return slots_.at(slot);
}

/**
* Attribute API.
*
* Wrappers around the slot stuff so that users can access attributes
* directly. Use this API if you are a user.
*
* Note: Unlike in Python, TorchScript must make a distinction between
* attributes (which are IValues) and methods (which are Methods). If you
* want a method, use `obj.type()->getMethod()`.
*/
IValue getAttr(const std::string& name) const;
void setAttr(const std::string& name, IValue v);

std::string name() const;

const std::vector<IValue>& slots() const {
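A minimal sketch of how the two layers above relate, assuming an already-constructed object (e.g. from Module::create_class below) and the ClassType helpers used elsewhere in this diff (getAttributeSlot, getMethod):

#include <ATen/core/ivalue.h>

void attribute_vs_slot(const c10::intrusive_ptr<c10::ivalue::Object>& obj) {
  // Attribute API: name-based; each call resolves the name to a slot.
  c10::IValue dx = obj->getAttr("dx");
  obj->setAttr("dx", c10::IValue(1.0));

  // Slot API: resolve the name once against the class type, then index
  // directly. This is the fast path recommended above for compiler code.
  const size_t slot = obj->type()->getAttributeSlot("dx");
  c10::IValue same = obj->getSlot(slot);
  obj->setSlot(slot, std::move(same));

  // Methods are not attributes; look them up on the type instead.
  auto init = obj->type()->getMethod("__init__");
  (void)init;
}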
2 changes: 2 additions & 0 deletions aten/src/ATen/core/type.cpp
@@ -156,6 +156,8 @@ TypePtr incompleteInferTypeFrom(const IValue& value) {
return TupleType::create(fmap(value.toTuple()->elements(), incompleteInferTypeFrom));
} else if (value.isDevice()) {
return DeviceObjType::get();
} else if (value.isObject()) {
return value.toObject()->type();
}
AT_ERROR("Type cannot be accurately recovered from this IValue.");
}
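A hedged sketch of what the new branch enables, assuming incompleteInferTypeFrom is visible to the caller and obj is an ivalue::Object as constructed elsewhere in this PR:

#include <ATen/core/ivalue.h>

c10::TypePtr infer_object_type(
    const c10::intrusive_ptr<c10::ivalue::Object>& obj) {
  // Before this change, an object IValue fell through to the AT_ERROR
  // above; now the class type is recovered directly from the object.
  c10::IValue v(obj);
  return incompleteInferTypeFrom(v);  // same result as obj->type()
}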
3 changes: 2 additions & 1 deletion test/cpp/jit/test.cpp
@@ -80,7 +80,8 @@ namespace jit {
_(SubgraphMatching) \
_(ModuleDefine) \
_(QualifiedName) \
_(ClassImport)
_(ClassImport) \
_(ScriptObject)

#define TH_FORALL_TESTS_CUDA(_) \
_(ArgumentSpec) \
38 changes: 37 additions & 1 deletion test/cpp/jit/test_class_import.h
@@ -1,8 +1,11 @@
#pragma once

#include <test/cpp/jit/test_base.h>
#include <test/cpp/jit/test_utils.h>

#include <ATen/core/qualified_name.h>
#include <torch/csrc/jit/import_source.h>
#include <torch/torch.h>

namespace torch {
namespace jit {
@@ -13,10 +16,12 @@ op_version_set = 1
class FooNestedTest:
def __init__(self, y):
self.y = y

class FooNestedTest2:
def __init__(self, y):
self.y = y
self.nested = __torch__.FooNestedTest(y)

class FooTest:
def __init__(self, x):
self.class_attr = __torch__.FooNestedTest(x)
@@ -58,6 +63,37 @@ void testClassImport() {
ASSERT_FALSE(c);
}

void testScriptObject() {
Module m1;
Module m2;
std::vector<at::Tensor> constantTable;
import_libs(
m1.class_compilation_unit(),
"__torch__",
classSrcs1,
constantTable,
nullptr);
import_libs(
m2.class_compilation_unit(),
"__torch__",
classSrcs2,
constantTable,
nullptr);

// Incorrect arguments for constructor should throw
c10::QualifiedName base("__torch__");
ASSERT_ANY_THROW(m1.create_class(c10::QualifiedName(base, "FooTest"), {1}));
auto x = torch::ones({2, 3});
auto obj = m2.create_class(c10::QualifiedName(base, "FooTest"), x).toObject();
auto dx = obj->getAttr("dx");
ASSERT_TRUE(test::almostEqual(x, dx.toTensor()));

auto new_x = torch::rand({2, 3});
obj->setAttr("dx", new_x);
auto new_dx = obj->getAttr("dx");
ASSERT_TRUE(test::almostEqual(new_x, new_dx.toTensor()));
}

} // namespace script
} // namespace jit
} // namespace torch
63 changes: 47 additions & 16 deletions torch/csrc/jit/script/module.cpp
@@ -1,12 +1,12 @@
#include <torch/csrc/jit/script/module.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/export.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/script/compiler.h>
#include <torch/csrc/jit/script/error_report.h>
#include <torch/csrc/jit/script/schema_matching.h>

namespace torch {
namespace jit {
@@ -106,15 +106,15 @@ void module_state_to(
const c10::optional<at::Device>& device,
const c10::optional<at::ScalarType>& dtype,
bool non_blocking) {
// Need to access the `at::Tensor` as a `Variable` here.
autograd::Variable variable = s.value().toTensor();
at::Tensor data = variable.data();
// Use the data's original device or dtype if not supplied here.
auto new_data = data.to(
device.value_or(data.device()),
dtype.value_or(data.scalar_type()),
non_blocking);
variable.set_data(new_data);
}

void Module::to_impl(
@@ -301,7 +301,6 @@ void Module::copy_into(
}
}


void Module::clone_method(
const Module& orig,
const std::string& name,
@@ -348,10 +347,42 @@ void Module::clone_method(const Module& orig, const std::string& name) {
}

void Module::train(bool on) {
for (auto& submod : get_modules()) {
submod->train(on);
}
register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
}

IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
// Classes live in the top-level compilation unit.
if (parent_) {
return parent_->create_class(name, std::move(stack));
}

// Look up the class
const auto classType =
class_compilation_unit().get_class(c10::QualifiedName(name));
if (!classType) {
AT_ERROR(
"Could not find class with name: '",
name.qualifiedName(),
"' in module.");
}

// Create a bare object with correct number of slots
const size_t numAttrs = classType->numAttributes();
auto obj = c10::ivalue::Object::create(classType, numAttrs);

// Invoke the `__init__()` of the class with the arguments provided.
Stack stackWithSelf = {obj};
for (auto& arg : stack) {
stackWithSelf.push_back(std::move(arg));
}
// Note: following Python, `__init__()` modifies its first parameter in-place
// and returns nothing.
classType->getMethod("__init__")->operator()(std::move(stackWithSelf));

return obj;
}

} // namespace script
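For intuition, create_class above follows the same construction protocol as Python object creation; a hedged sketch of the correspondence, in comment form:

// Module::create_class(name, stack) is roughly Python's:
//
//   cls = <resolve `name` in the top-level compilation unit>
//   obj = cls.__new__(cls)      # bare object with numAttributes() slots
//   cls.__init__(obj, *args)    # mutates obj in place, returns None
//   return obj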
8 changes: 8 additions & 0 deletions torch/csrc/jit/script/module.h
@@ -442,6 +442,14 @@ struct TORCH_API Module {
// so that C++ users can easily add methods
void define(const std::string& src, const ResolverPtr& resolver = nullptr);

template <typename... Types>
IValue create_class(const c10::QualifiedName& name, Types&&... args)
const {
return create_class(name, {IValue(std::forward<Types>(args))...});
}

IValue create_class(const c10::QualifiedName& name, Stack stack) const;

private:
std::pair<std::shared_ptr<Function>, std::vector<Slot>>
lower_first_class_method(Function* fn);
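Since the variadic overload above just packs its arguments into IValues and forwards to the Stack overload, the two spellings are interchangeable. A hedged sketch, assuming a module m that has already imported __torch__.FooTest as in the test:

// Variadic form: arguments are converted to IValues automatically.
auto a = m.create_class(
    c10::QualifiedName("__torch__.FooTest"), torch::ones({2, 3}));

// Explicit-stack form: equivalent, and useful when the arguments are
// assembled dynamically.
Stack stack = {IValue(torch::ones({2, 3}))};
auto b = m.create_class(
    c10::QualifiedName("__torch__.FooTest"), std::move(stack));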