1 change: 1 addition & 0 deletions test/test_jit.py

@@ -3487,6 +3487,7 @@ def forward(self, x):
         self.assertEqual(4, w(3))
         w.train(False)
         self.assertEqual(7, w(3))
+        self.assertFalse("training" in w.state_dict())

     def test_jitter_bug(self):
         @torch.jit.script
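The new assertion pins down the user-visible effect: `training` is no longer a buffer, so it no longer shows up in `state_dict()`. The test module `w` is defined above this hunk and not shown; a minimal sketch that reproduces the same behavior (the class name and the int-typed signature are assumptions) is:

    import torch

    class M(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            # type: (int) -> int
            if self.training:
                return x + 1  # 3 -> 4 in training mode
            return x + 4      # 3 -> 7 in eval mode

    m = M()
    assert m(3) == 4
    m.train(False)
    assert m(3) == 7
    # 'training' is a bool attribute now, not a Long buffer, so it is
    # absent from the serialized state:
    assert "training" not in m.state_dict()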
8 changes: 8 additions & 0 deletions torch/csrc/jit/script/init.cpp

@@ -335,6 +335,14 @@ void initJitScriptBindings(PyObject* module) {
       })
       .def("_register_module", &Module::register_module)
       .def("_register_buffer", &Module::register_buffer)
+      .def(
+          "_set_attribute",
+          [](Module& self, const std::string& name, py::object value) {
+            auto attr = self.find_attribute(name);
+            AT_CHECK(attr != nullptr, "Could not find attribute '", name, "'");
+            auto ivalue = toIValue(value, attr->type());
+            attr->setValue(ivalue);
+          })
       .def("_set_parameter", &Module::set_parameter)
       .def("_get_parameter", &Module::get_parameter)
       .def("_get_buffer", &Module::get_buffer)
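The binding validates before it writes: `find_attribute` locates the slot, `AT_CHECK` rejects unknown names, and `toIValue` converts the Python value against the attribute's declared type. From Python it behaves roughly as below (internal `_c` API, so treat this as a sketch; `m` is the toy module from the first example):

    m._c._set_attribute("training", False)    # converted against BoolType
    try:
        m._c._set_attribute("no_such_attr", 1)
    except RuntimeError as e:                 # AT_CHECK surfaces as RuntimeError
        print(e)                              # Could not find attribute 'no_such_attr'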
6 changes: 5 additions & 1 deletion torch/csrc/jit/script/module.cpp

@@ -306,7 +306,11 @@ void Module::train(bool on) {
   for (auto& submod : get_modules()) {
     submod->train(on);
   }
-  register_buffer("training", torch::tensor(on ? 1 : 0, at::kLong));
+  if (auto slot = find_attribute("training")) {
+    slot->setValue(on);
+  } else {
+    register_attribute("training", BoolType::get(), on);
+  }
 }

 IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const {
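`Module::train` now updates the existing `training` slot in place and only registers the attribute on first use, instead of re-registering a Long buffer on every toggle. The observable Python behavior, reusing `m` from above (a sketch, not the PR's own test):

    m.train(False)   # flips the existing bool slot in place
    m.eval()         # equivalent to m.train(False)
    m.train()        # back to training mode; no new slot is registered
    assert m.training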
4 changes: 2 additions & 2 deletions torch/csrc/jit/script/module.h

@@ -348,8 +348,8 @@ struct TORCH_API Module {
   }
   /// True if the module is in training mode.
   bool is_training() {
-    if (auto p = find_buffer("training")) {
-      return p->value().toTensor().item<int64_t>() == 1;
+    if (auto p = find_attribute("training")) {
+      return p->value().toBool();
     }
     // We are in training mode by default
     return true;
12 changes: 5 additions & 7 deletions torch/csrc/jit/script/python_sugared_value.cpp

@@ -252,16 +252,14 @@ std::shared_ptr<SugaredValue> ModuleValue::attr(
   // it adds a buffer 'training' to the model if one doesn't exist
   // and then loads that parameter, casting it to bool
   if (field == "training") {
-    Slot* v = module_->find_buffer(field);
+    Slot* v = module_->find_attribute(field);
     if (!v) {
       bool training = py::cast<bool>(py::getattr(py_module_, "training"));
-      auto t =
-          autograd::make_variable(at::full({}, training ? 1 : 0, at::kLong));
-      module_->register_buffer("training", std::move(t));
-      v = module_->find_buffer(field);
+      module_->register_attribute(
+          "training", BoolType::get(), std::move(training));
+      v = module_->find_attribute(field);
     }
-    Value* the_tensor = m.graph()->insertGetAttr(self_, "training");
-    Value* the_bool = m.graph()->insert(prim::Bool, {the_tensor});
+    Value* the_bool = m.graph()->insertGetAttr(self_, "training");
     return std::make_shared<SimpleValue>(the_bool);
   }

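Note the compiled-graph payoff: `self.training` used to fetch a tensor via `prim::GetAttr` and then cast it with `prim::Bool`; it now lowers to a single `prim::GetAttr` that already produces a bool. One way to observe this, with `m` from the first sketch (the comment describes the expected shape, not verbatim output):

    print(m.forward.graph)
    # Expect a lone prim::GetAttr[name="training"] of type bool in place of
    # the old GetAttr-to-Tensor followed by a prim::Bool conversion.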
9 changes: 4 additions & 5 deletions torch/jit/__init__.py

@@ -1441,6 +1441,8 @@ def get_debug_state(self, *args, **kwargs):
         def __getattr__(self, attr):
             if '_c' not in self.__dict__:
                 raise RuntimeError("ScriptModule has not been initialized, did you forget to call super's init?")
+            if self._c._has_attribute(attr):
+                return self._c._get_attribute(attr)
             if self._c._has_method(attr):
                 if attr in self.__class__._methods:
                     original_method = self.__class__._methods[attr].original_method
@@ -1452,9 +1454,6 @@ def __getattr__(self, attr):
                 # to improve invocation performance
                 self.__dict__[attr] = script_method
                 return script_method
-
-            if self._c._has_attribute(attr):
-                return self._c._get_attribute(attr)
             return Module.__getattr__(self, attr)

         def __setattr__(self, attr, value):
@@ -1463,9 +1462,9 @@ def __setattr__(self, attr, value):
                 # Compile weak script module
                 value = _make_strong(value)
             if attr == 'training':
-                if self._c._has_buffer('training'):
+                if self._c._has_attribute('training'):
                     self.__dict__['training'] = value
-                    self._c._get_buffer('training').fill_(int(value))
+                    self._c._set_attribute('training', value)
                     return
             if isinstance(value, Attribute):
                 the_type = torch.jit.annotations.ann_to_type(value.type)
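On the Python side, script attributes now win over script methods in `__getattr__`, and assigning `training` goes through the typed setter instead of `fill_` on a Long buffer. A sketch against the toy module (the internal `_c` calls are shown only to make the routing visible):

    m2 = M()
    # forward references self.training, so compilation already registered it:
    assert m2._c._has_attribute("training")
    m2.training = False        # __setattr__ -> m2._c._set_attribute('training', False)
    assert not m2._c._get_attribute("training")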
9 changes: 8 additions & 1 deletion torch/nn/parallel/replicate.py

@@ -123,10 +123,17 @@ def replicate(network, devices, detach=False):
             # we have to initialize ScriptModule properly so that
             # it works with pybind11
             replica = _init_script_module()
-            keys = set(module.__dict__.keys()) - scriptmodule_skip_attr
+
+            attribute_names = set(entry[0] for entry in module._c._get_attributes())
+
+            keys = set(module.__dict__.keys()) - scriptmodule_skip_attr - attribute_names
             for key in keys:
                 if not _is_script_method(module.__dict__[key]):
                     replica.__dict__[key] = module.__dict__[key]
+            for name, the_type, value in module._c._get_attributes():
+                if name in module._buffers.keys():
+                    continue
+                replica._c._register_attribute(name, the_type, value)
         else:
             replica = module.__new__(type(module))
             replica.__dict__ = module.__dict__.copy()
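Attributes that share a name with a buffer are skipped here because `replicate()` already copies buffers through the regular `nn.Module` path; everything else, including the `training` bool, is re-registered on each replica. A rough end-to-end sketch (assumes two CUDA devices; `M` is the toy module from the first example):

    from torch.nn.parallel import replicate

    m3 = M().cuda(0)
    m3.eval()
    replicas = replicate(m3, [0, 1])
    # Each replica re-registers the script attributes, so eval mode
    # survives replication:
    assert all(not r.training for r in replicas)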