|
1 | | -#include <torch/csrc/jit/script/module.h> |
2 | 1 | #include <c10/util/Exception.h> |
| 2 | +#include <torch/csrc/autograd/generated/variable_factories.h> |
3 | 3 | #include <torch/csrc/jit/export.h> |
4 | 4 | #include <torch/csrc/jit/operator.h> |
5 | 5 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
6 | 6 | #include <torch/csrc/jit/script/compiler.h> |
7 | 7 | #include <torch/csrc/jit/script/error_report.h> |
| 8 | +#include <torch/csrc/jit/script/module.h> |
8 | 9 | #include <torch/csrc/jit/script/schema_matching.h> |
9 | | -#include <torch/csrc/autograd/generated/variable_factories.h> |
10 | 10 |
|
11 | 11 | namespace torch { |
12 | 12 | namespace jit { |
@@ -106,15 +106,15 @@ void module_state_to( |
106 | 106 | const c10::optional<at::Device>& device, |
107 | 107 | const c10::optional<at::ScalarType>& dtype, |
108 | 108 | bool non_blocking) { |
109 | | - // Need to access the `at::Tensor` as a `Variable` here. |
110 | | - autograd::Variable variable = s.value().toTensor(); |
111 | | - at::Tensor data = variable.data(); |
112 | | - // Use the data's original device or dtype if not supplied here. |
113 | | - auto new_data = data.to( |
114 | | - device.value_or(data.device()), |
115 | | - dtype.value_or(data.scalar_type()), |
116 | | - non_blocking); |
117 | | - variable.set_data(new_data); |
| 109 | + // Need to access the `at::Tensor` as a `Variable` here. |
| 110 | + autograd::Variable variable = s.value().toTensor(); |
| 111 | + at::Tensor data = variable.data(); |
| 112 | + // Use the data's original device or dtype if not supplied here. |
| 113 | + auto new_data = data.to( |
| 114 | + device.value_or(data.device()), |
| 115 | + dtype.value_or(data.scalar_type()), |
| 116 | + non_blocking); |
| 117 | + variable.set_data(new_data); |
118 | 118 | } |
119 | 119 |
|
120 | 120 | void Module::to_impl( |
@@ -301,7 +301,6 @@ void Module::copy_into( |
301 | 301 | } |
302 | 302 | } |
303 | 303 |
|
304 | | - |
305 | 304 | void Module::clone_method( |
306 | 305 | const Module& orig, |
307 | 306 | const std::string& name, |
@@ -348,10 +347,42 @@ void Module::clone_method(const Module& orig, const std::string& name) { |
348 | 347 | } |
349 | 348 |
|
350 | 349 | void Module::train(bool on) { |
351 | | - for (auto& submod : get_modules()) { |
352 | | - submod->train(on); |
353 | | - } |
354 | | - register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong)); |
| 350 | + for (auto& submod : get_modules()) { |
| 351 | + submod->train(on); |
| 352 | + } |
| 353 | + register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong)); |
| 354 | +} |
| 355 | + |
| 356 | +IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const { |
| 357 | + // Classes live in the top-level compilation unit. |
| 358 | + if (parent_) { |
| 359 | + return parent_->create_class(name, std::move(stack)); |
| 360 | + } |
| 361 | + |
| 362 | + // Look up the class |
| 363 | + const auto classType = |
| 364 | + class_compilation_unit().get_class(c10::QualifiedName(name)); |
| 365 | + if (!classType) { |
| 366 | + AT_ERROR( |
| 367 | + "Could not find class with name: '", |
| 368 | + name.qualifiedName(), |
| 369 | + "' in module."); |
| 370 | + } |
| 371 | + |
| 372 | + // Create a bare object with correct number of slots |
| 373 | + const size_t numAttrs = classType->numAttributes(); |
| 374 | + auto obj = c10::ivalue::Object::create(classType, numAttrs); |
| 375 | + |
| 376 | + // Invoke the `__init__()` of the class with the arguments provided. |
| 377 | + Stack stackWithSelf = {obj}; |
| 378 | + for (auto& arg : stack) { |
| 379 | + stackWithSelf.push_back(std::move(arg)); |
| 380 | + } |
| 381 | + // Note: following Python, `__init__()` modifies its first parameter in-place |
| 382 | + // and returns nothing. |
| 383 | + classType->getMethod("__init__")->operator()(std::move(stackWithSelf)); |
| 384 | + |
| 385 | + return obj; |
355 | 386 | } |
356 | 387 |
|
357 | 388 | } // namespace script |
|
0 commit comments