Skip to content

Commit ce03348

Browse files
suofacebook-github-bot
authored andcommitted
Convenience APIs for script objects (#20226)
Summary: Pull Request resolved: #20226 ghimport-source-id: 22937d72e35ec4eba38019284a368453089fe3eb Differential Revision: D15243625 Pulled By: suo fbshipit-source-id: 5e9fb773da244f9ef201dba524155c3b19b2b4e0
1 parent 50149fb commit ce03348

File tree

7 files changed

+134
-18
lines changed

7 files changed

+134
-18
lines changed

aten/src/ATen/core/ivalue.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,16 @@ std::string ivalue::Object::name() const {
113113
return this->type_->qualname();
114114
}
115115

116+
IValue ivalue::Object::getAttr(const std::string& name) const {
117+
const size_t slot = type_->getAttributeSlot(name);
118+
return getSlot(slot);
119+
}
120+
121+
void ivalue::Object::setAttr(const std::string& name, IValue v) {
122+
const size_t slot = type_->getAttributeSlot(name);
123+
setSlot(slot, std::move(v));
124+
}
125+
116126
void ivalue::Object::resizeObject(size_t slot) {
117127
AT_ASSERT(slot < type()->numAttributes());
118128
slots_.resize(type()->numAttributes());

aten/src/ATen/core/ivalue.h

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@
1212

1313
#include <ATen/core/Tensor.h>
1414

15+
namespace torch {
16+
namespace jit {
17+
namespace script {
18+
struct Function;
19+
}
20+
} // namespace jit
21+
} // namespace torch
1522
namespace c10 {
1623
struct IValue;
1724
struct ClassType;
@@ -695,6 +702,14 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
695702
return c10::make_intrusive<Object>(std::move(type), numSlots);
696703
}
697704

705+
/**
706+
* Slot API.
707+
*
708+
* Attributes are stored as a simple vector so that lookups are fast at
709+
* runtime. A "slot" is just an index into that vector, which can be computed
710+
* statically if you have access to the class type. Use this API if you are
711+
* writing compiler stuff.
712+
*/
698713
void setSlot(size_t slot, IValue v) {
699714
if (slot >= slots_.size()) {
700715
// for module types, it is possible that the members of the class have
@@ -709,6 +724,19 @@ struct C10_EXPORT ivalue::Object final : c10::intrusive_ptr_target {
709724
return slots_.at(slot);
710725
}
711726

727+
/**
728+
* Attribute API.
729+
*
730+
* Wrappers around the slot stuff so that users can access attributes
731+
* directly. Use this API if you are a user.
732+
*
733+
* Note: Unlike in Python, TorchScript must make a distinction between
734+
* attributes (which are IValues) and methods (which are Methods). If you
735+
* want a method, use `obj.type()->getMethod()`
736+
*/
737+
IValue getAttr(const std::string& name) const;
738+
void setAttr(const std::string& name, IValue v);
739+
712740
std::string name() const;
713741

714742
const std::vector<IValue>& slots() const {

aten/src/ATen/core/type.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ TypePtr incompleteInferTypeFrom(const IValue& value) {
156156
return TupleType::create(fmap(value.toTuple()->elements(), incompleteInferTypeFrom));
157157
} else if (value.isDevice()) {
158158
return DeviceObjType::get();
159+
} else if (value.isObject()) {
160+
return value.toObject()->type();
159161
}
160162
AT_ERROR("Type cannot be accurately recovered from this IValue.");
161163
}

test/cpp/jit/test.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ namespace jit {
8080
_(SubgraphMatching) \
8181
_(ModuleDefine) \
8282
_(QualifiedName) \
83-
_(ClassImport)
83+
_(ClassImport) \
84+
_(ScriptObject)
8485

8586
#define TH_FORALL_TESTS_CUDA(_) \
8687
_(ArgumentSpec) \

test/cpp/jit/test_class_import.h

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
#pragma once
22

3-
#include <ATen/core/qualified_name.h>
43
#include <test/cpp/jit/test_base.h>
4+
#include <test/cpp/jit/test_utils.h>
5+
6+
#include <ATen/core/qualified_name.h>
57
#include <torch/csrc/jit/import_source.h>
8+
#include <torch/torch.h>
69

710
namespace torch {
811
namespace jit {
@@ -13,10 +16,12 @@ op_version_set = 1
1316
class FooNestedTest:
1417
def __init__(self, y):
1518
self.y = y
19+
1620
class FooNestedTest2:
1721
def __init__(self, y):
1822
self.y = y
1923
self.nested = __torch__.FooNestedTest(y)
24+
2025
class FooTest:
2126
def __init__(self, x):
2227
self.class_attr = __torch__.FooNestedTest(x)
@@ -58,6 +63,37 @@ void testClassImport() {
5863
ASSERT_FALSE(c);
5964
}
6065

66+
void testScriptObject() {
67+
Module m1;
68+
Module m2;
69+
std::vector<at::Tensor> constantTable;
70+
import_libs(
71+
m1.class_compilation_unit(),
72+
"__torch__",
73+
classSrcs1,
74+
constantTable,
75+
nullptr);
76+
import_libs(
77+
m2.class_compilation_unit(),
78+
"__torch__",
79+
classSrcs2,
80+
constantTable,
81+
nullptr);
82+
83+
// Incorrect arguments for constructor should throw
84+
c10::QualifiedName base("__torch__");
85+
ASSERT_ANY_THROW(m1.create_class(c10::QualifiedName(base, "FooTest"), {1}));
86+
auto x = torch::ones({2, 3});
87+
auto obj = m2.create_class(c10::QualifiedName(base, "FooTest"), x).toObject();
88+
auto dx = obj->getAttr("dx");
89+
ASSERT_TRUE(test::almostEqual(x, dx.toTensor()));
90+
91+
auto new_x = torch::rand({2, 3});
92+
obj->setAttr("dx", new_x);
93+
auto new_dx = obj->getAttr("dx");
94+
ASSERT_TRUE(test::almostEqual(new_x, new_dx.toTensor()));
95+
}
96+
6197
} // namespace script
6298
} // namespace jit
6399
} // namespace torch

torch/csrc/jit/script/module.cpp

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
#include <torch/csrc/jit/script/module.h>
21
#include <c10/util/Exception.h>
2+
#include <torch/csrc/autograd/generated/variable_factories.h>
33
#include <torch/csrc/jit/export.h>
44
#include <torch/csrc/jit/operator.h>
55
#include <torch/csrc/jit/passes/dead_code_elimination.h>
66
#include <torch/csrc/jit/script/compiler.h>
77
#include <torch/csrc/jit/script/error_report.h>
8+
#include <torch/csrc/jit/script/module.h>
89
#include <torch/csrc/jit/script/schema_matching.h>
9-
#include <torch/csrc/autograd/generated/variable_factories.h>
1010

1111
namespace torch {
1212
namespace jit {
@@ -106,15 +106,15 @@ void module_state_to(
106106
const c10::optional<at::Device>& device,
107107
const c10::optional<at::ScalarType>& dtype,
108108
bool non_blocking) {
109-
// Need to access the `at::Tensor` as a `Variable` here.
110-
autograd::Variable variable = s.value().toTensor();
111-
at::Tensor data = variable.data();
112-
// Use the data's original device or dtype if not supplied here.
113-
auto new_data = data.to(
114-
device.value_or(data.device()),
115-
dtype.value_or(data.scalar_type()),
116-
non_blocking);
117-
variable.set_data(new_data);
109+
// Need to access the `at::Tensor` as a `Variable` here.
110+
autograd::Variable variable = s.value().toTensor();
111+
at::Tensor data = variable.data();
112+
// Use the data's original device or dtype if not supplied here.
113+
auto new_data = data.to(
114+
device.value_or(data.device()),
115+
dtype.value_or(data.scalar_type()),
116+
non_blocking);
117+
variable.set_data(new_data);
118118
}
119119

120120
void Module::to_impl(
@@ -301,7 +301,6 @@ void Module::copy_into(
301301
}
302302
}
303303

304-
305304
void Module::clone_method(
306305
const Module& orig,
307306
const std::string& name,
@@ -348,10 +347,42 @@ void Module::clone_method(const Module& orig, const std::string& name) {
348347
}
349348

350349
void Module::train(bool on) {
351-
for (auto& submod : get_modules()) {
352-
submod->train(on);
353-
}
354-
register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
350+
for (auto& submod : get_modules()) {
351+
submod->train(on);
352+
}
353+
register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
354+
}
355+
356+
IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
357+
// Classes live in the top-level compilation unit.
358+
if (parent_) {
359+
return parent_->create_class(name, std::move(stack));
360+
}
361+
362+
// Look up the class
363+
const auto classType =
364+
class_compilation_unit().get_class(c10::QualifiedName(name));
365+
if (!classType) {
366+
AT_ERROR(
367+
"Could not find class with name: '",
368+
name.qualifiedName(),
369+
"' in module.");
370+
}
371+
372+
// Create a bare object with correct number of slots
373+
const size_t numAttrs = classType->numAttributes();
374+
auto obj = c10::ivalue::Object::create(classType, numAttrs);
375+
376+
// Invoke the `__init__()` of the class with the arguments provided.
377+
Stack stackWithSelf = {obj};
378+
for (auto& arg : stack) {
379+
stackWithSelf.push_back(std::move(arg));
380+
}
381+
// Note: following Python, `__init__()` modifies its first parameter in-place
382+
// and returns nothing.
383+
classType->getMethod("__init__")->operator()(std::move(stackWithSelf));
384+
385+
return obj;
355386
}
356387

357388
} // namespace script

torch/csrc/jit/script/module.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,14 @@ struct TORCH_API Module {
442442
// so that C++ users can easily add methods
443443
void define(const std::string& src, const ResolverPtr& resolver = nullptr);
444444

445+
template <typename... Types>
446+
IValue create_class(const c10::QualifiedName& name, Types&&... args)
447+
const {
448+
return create_class(name, {IValue(std::forward<Types>(args))...});
449+
}
450+
451+
IValue create_class(const c10::QualifiedName& name, Stack stack) const;
452+
445453
private:
446454
std::pair<std::shared_ptr<Function>, std::vector<Slot>>
447455
lower_first_class_method(Function* fn);

0 commit comments

Comments
 (0)