Skip to content

Commit 036f618

Browse files
committed
address comments
1 parent b6fddcf commit 036f618

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

tools/autograd/templates/python_variable_methods.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -563,21 +563,18 @@ static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwarg
563563
auto& device = std::get<0>(parsed);
564564
auto& scalarType = std::get<1>(parsed);
565565
auto non_blocking = std::get<2>(parsed);
566-
if (!device && scalarType) {
567-
// only dtype given
566+
if (!device) {
567+
// device not given
568568
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
569-
auto& type = self_.type().toScalarType(*scalarType);
569+
auto& type = self_.type().toScalarType(scalarType.value_or(self_.type().scalarType()));
570570
return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type));
571-
} else if (device) {
572-
// device and maybe dtype given
571+
} else {
572+
// device and maybe dtype are given
573573
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
574574
auto deviceAutoGPU = device->deviceInt64();
575575
auto& layout = *torch::getLayout(self_.type().backend());
576576
auto& type = torch::getType(scalarType.value_or(self_.type().scalarType()), layout, device->type);
577577
return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type, deviceAutoGPU, non_blocking));
578-
} else {
579-
Py_INCREF(self);
580-
return self;
581578
}
582579
Py_RETURN_NONE;
583580
END_HANDLE_TH_ERRORS

0 commit comments

Comments
 (0)