Skip to content

Commit 236c2b2

Browse files
wanchaol
authored and facebook-github-bot committed
Let script module buffer attributes also cast device/type (#19700)
Summary: Tested locally; this fixes #19039. Did not add a test at first since there was no way to create a script module in the C++ world. Pull Request resolved: #19700 Differential Revision: D15094195 Pulled By: wanchaol fbshipit-source-id: fcc2c1e5efbc160d976ae485ba2457442f62f065
1 parent 5099db0 commit 236c2b2

File tree

3 files changed

+47
-8
lines changed

3 files changed

+47
-8
lines changed

test/cpp/jit/test.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ namespace jit {
8080
_(ArgumentSpec) \
8181
_(Fusion) \
8282
_(GraphExecutor) \
83+
_(ModuleConversion) \
8384
_(Interp)
8485

8586
#if defined(USE_GTEST)

test/cpp/jit/test_misc.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,6 +758,30 @@ void testModuleDefine() {
758758
AT_ASSERT(result.toTensor().item<float>() == 6)
759759
}
760760

761+
void testModuleConversion() {
762+
auto m = std::make_shared<script::Module>();
763+
{
764+
// test cuda to cpu for params and buffers
765+
m->register_parameter("foo", torch::ones({}, at::kCUDA), false);
766+
m->register_buffer("bar", torch::ones({}, at::kCUDA));
767+
768+
m->to(at::kCUDA);
769+
m->to(at::kCPU);
770+
AT_ASSERT(m->get_parameter("foo").data().device().is_cpu());
771+
AT_ASSERT(m->get_buffer("bar").data().device().is_cpu());
772+
}
773+
{
774+
// test cpu to cuda for params and buffers
775+
m->register_parameter("foo", torch::ones({}), false);
776+
m->register_buffer("bar", torch::ones({}));
777+
778+
m->to(at::kCUDA);
779+
AT_ASSERT(m->get_parameter("foo").data().device().is_cuda());
780+
AT_ASSERT(m->get_buffer("bar").data().device().is_cuda());
781+
}
782+
}
783+
784+
761785
static int testPassValue = 0;
762786
void fakePass(std::shared_ptr<Graph>& g) {
763787
testPassValue++;

torch/csrc/jit/script/module.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,25 +101,39 @@ void Module::save(
101101
ExportModule(*this, filename, extra_files);
102102
}
103103

104-
void Module::to_impl(
104+
void module_state_to(
105+
const Slot& s,
105106
const c10::optional<at::Device>& device,
106107
const c10::optional<at::ScalarType>& dtype,
107108
bool non_blocking) {
108-
// First call `to()` on every child module.
109-
for (auto& child : get_modules()) {
110-
child->to_impl(device, dtype, non_blocking);
111-
}
112-
// Then convert every of our parameters.
113-
for (auto& parameter : get_parameters()) {
114109
// Need to access the `at::Tensor` as a `Variable` here.
115-
autograd::Variable variable = parameter.value().toTensor();
110+
autograd::Variable variable = s.value().toTensor();
116111
at::Tensor data = variable.data();
117112
// Use the data's original device or dtype if not supplied here.
118113
auto new_data = data.to(
119114
device.value_or(data.device()),
120115
dtype.value_or(data.scalar_type()),
121116
non_blocking);
122117
variable.set_data(new_data);
118+
}
119+
120+
void Module::to_impl(
121+
const c10::optional<at::Device>& device,
122+
const c10::optional<at::ScalarType>& dtype,
123+
bool non_blocking) {
124+
// First call `to()` on every child module.
125+
for (auto& child : get_modules()) {
126+
child->to_impl(device, dtype, non_blocking);
127+
}
128+
// Then convert every of our parameters.
129+
for (auto& parameter : get_parameters()) {
130+
module_state_to(parameter, device, dtype, non_blocking);
131+
}
132+
// Then convert every tensor attributes (buffers).
133+
for (auto& attr : get_attributes()) {
134+
if (attr.type()->isSubtypeOf(TensorType::get())) {
135+
module_state_to(attr, device, dtype, non_blocking);
136+
}
123137
}
124138
}
125139

0 commit comments

Comments
 (0)