
Commit c0a419e

Add non_blocking to Tensor/Module.to (#7312)
* Add non_blocking to Tensor/Module.to
* flake8
* Add argparse tests
* cpp parse
* Use C++ parser
* use a common parse function with Tensor.to
* fix test_jit
* use THPObjectPtr
* increase refcount for None, True, and False
* address comments
* address comments
1 parent ec4a0f3 commit c0a419e
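
The resulting Python surface, as a quick sketch (illustration only, not part of the diff; the device and dtype values are arbitrary): every call form of Tensor.to and Module.to now accepts a trailing non_blocking flag.

    import torch
    import torch.nn as nn

    t = torch.randn(2, 2)
    t.to(torch.float64, non_blocking=True)           # dtype only
    t.to('cpu', torch.float64, non_blocking=True)    # device, optional dtype
    t.to(torch.zeros(1), non_blocking=True)          # adopt another tensor's device/dtype

    net = nn.Linear(3, 3)
    net.to('cpu', torch.float64, True)               # non_blocking also parses positionally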

File tree

14 files changed: +178 -112 lines


test/test_jit.py

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ def get_lstm_inputs(device):
     input = torch.randn(3, 10, dtype=torch.float, device=device)
     hx = torch.randn(3, 20, dtype=torch.float, device=device)
     cx = torch.randn(3, 20, dtype=torch.float, device=device)
-    module = nn.LSTMCell(10, 20).to(torch.float, device)  # Just to allocate weights with correct sizes
+    module = nn.LSTMCell(10, 20).to(device, torch.float)  # Just to allocate weights with correct sizes
     return (input, hx, cx) + tuple(p.requires_grad_(False) for p in module.parameters())
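
The swap above follows from the new shared parser: device occupies the first positional slot, so a dtype can no longer precede a device. A minimal sketch ('cpu' stands in for the test's device argument):

    import torch
    import torch.nn as nn

    m = nn.LSTMCell(10, 20)
    m.to('cpu', torch.float)         # OK: to(device, dtype)
    try:
        m.to(torch.float, 'cpu')     # a dtype in the device slot no longer parses
    except TypeError as e:
        print(e)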

test/test_nn.py

Lines changed: 26 additions & 4 deletions
@@ -1175,6 +1175,28 @@ def test_add_module(self):
         self.assertEqual(net.l, l3)
         self.assertRaises(TypeError, lambda: net.add_module('x', 'non-module'))
 
+    def test_module_to_argparse(self):
+        net = nn.Sequential(nn.Linear(3, 3))
+        cpu = torch.device('cpu')
+        with self.assertRaises(TypeError):
+            net.to(cpu, True)
+        with self.assertRaises(TypeError):
+            net.to(torch.long)
+        with self.assertRaises(TypeError):
+            net.to(None, True)
+        with self.assertRaises(TypeError):
+            net.to(cpu, torch.long, True)
+        with self.assertRaises(TypeError):
+            net.to(cpu, dtype=torch.long, non_blocking=True)
+        with self.assertRaises(TypeError):
+            net.to([])
+        with self.assertRaises(TypeError):
+            net.to({}, non_blocking=True)
+        with self.assertRaises(TypeError):
+            net.to(torch.tensor(3, dtype=torch.long), non_blocking=True)
+        with self.assertRaises(TypeError):
+            net.to(cpu, torch.tensor(3, dtype=torch.long), non_blocking=True)
+
     def test_type(self):
         l = nn.Linear(10, 20)
         net = nn.Module()

@@ -1203,22 +1225,22 @@ def test_type(self):
         self.assertIsInstance(l.weight.data, torch.FloatTensor)
         self.assertIsInstance(l.bias.data, torch.FloatTensor)
         self.assertIsInstance(net.indices, torch.LongTensor)
-        net.to("cuda", torch.double)
+        net.to("cuda", torch.double, True)
         self.assertIsInstance(l.weight.data, torch.cuda.DoubleTensor)
         self.assertIsInstance(l.bias.data, torch.cuda.DoubleTensor)
         self.assertIsInstance(net.indices, torch.cuda.LongTensor)
-        net.to(device="cuda:0", dtype=torch.half)
+        net.to(torch.empty(1, device="cuda:0", dtype=torch.half))
         self.assertIsInstance(l.weight.data, torch.cuda.HalfTensor)
         self.assertIsInstance(l.bias.data, torch.cuda.HalfTensor)
         self.assertIsInstance(net.indices, torch.cuda.LongTensor)
-        net.to(torch.device("cpu"))
+        net.to(torch.device("cpu"), non_blocking=True)
         self.assertIsInstance(l.weight.data, torch.HalfTensor)
         self.assertIsInstance(l.bias.data, torch.HalfTensor)
         self.assertIsInstance(net.indices, torch.LongTensor)
         net.type(torch.FloatTensor)
         self.assertIsInstance(l.weight.data, torch.FloatTensor)
         self.assertIsInstance(l.bias.data, torch.FloatTensor)
-        net.type(torch.DoubleTensor)
+        net.to(torch.DoubleTensor(1))
         self.assertIsInstance(l.weight.data, torch.DoubleTensor)
         self.assertIsInstance(l.bias.data, torch.DoubleTensor)
         if TEST_CUDA:
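
Distilled from the assertions above, a sketch of which Module.to call forms the parser accepts and which it rejects (modules additionally require a floating-point dtype):

    import torch
    import torch.nn as nn

    net = nn.Sequential(nn.Linear(3, 3))
    cpu = torch.device('cpu')

    net.to(cpu)                    # device only
    net.to(cpu, torch.float64)     # device plus floating-point dtype
    net.to(torch.zeros(1))         # another tensor supplies device and dtype

    # Rejected: a bool in the dtype slot, and non-floating dtypes for modules.
    for bad_call in (lambda: net.to(cpu, True),
                     lambda: net.to(torch.long),
                     lambda: net.to(cpu, dtype=torch.long, non_blocking=True)):
        try:
            bad_call()
        except TypeError:
            pass  # expected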

test/test_torch.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1858,26 +1858,28 @@ def test_to(self):
         self.assertIs(torch.float32, a.to(dtype=torch.float32).dtype)
 
         if torch.cuda.is_available():
-            for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
-                b = torch.tensor(5., device=cuda)
-                self.assertEqual(b.device, b.to(cuda).device)
-                self.assertEqual(a.device, b.to('cpu').device)
-                self.assertEqual(b.device, a.to(cuda).device)
-                self.assertIs(torch.int32, b.to('cpu', dtype=torch.int32).dtype)
-                self.assertEqual(a.device, b.to('cpu', dtype=torch.int32).device)
-                self.assertIs(torch.int32, b.to(dtype=torch.int32).dtype)
-                self.assertEqual(b.device, b.to(dtype=torch.int32).device)
+            for non_blocking in [True, False]:
+                for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
+                    b = torch.tensor(5., device=cuda)
+                    self.assertEqual(b.device, b.to(cuda, non_blocking=non_blocking).device)
+                    self.assertEqual(a.device, b.to('cpu', non_blocking=non_blocking).device)
+                    self.assertEqual(b.device, a.to(cuda, non_blocking=non_blocking).device)
+                    self.assertIs(torch.int32, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).dtype)
+                    self.assertEqual(a.device, b.to('cpu', dtype=torch.int32, non_blocking=non_blocking).device)
+                    self.assertIs(torch.int32, b.to(dtype=torch.int32).dtype)
+                    self.assertEqual(b.device, b.to(dtype=torch.int32).device)
 
     def test_to_with_tensor(self):
         a = torch.tensor(5)
         self.assertEqual(a.device, a.to(a).device)
 
         if torch.cuda.is_available():
-            for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
-                b = torch.tensor(5., device=cuda)
-                self.assertEqual(b.device, b.to(b).device)
-                self.assertEqual(a.device, b.to(a).device)
-                self.assertEqual(b.device, a.to(b).device)
+            for non_blocking in [True, False]:
+                for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
+                    b = torch.tensor(5., device=cuda)
+                    self.assertEqual(b.device, b.to(b, non_blocking=non_blocking).device)
+                    self.assertEqual(a.device, b.to(a, non_blocking=non_blocking).device)
+                    self.assertEqual(b.device, a.to(b, non_blocking=non_blocking).device)
 
     @staticmethod
     def _test_empty_full(self, dtypes, layout, device):
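
What the doubled loop exercises: non_blocking is accepted by every overload and, per the documentation below, simply performs an ordinary blocking copy when an asynchronous one is not possible (e.g., CPU to CPU). A minimal sketch:

    import torch

    a = torch.tensor(5)
    for non_blocking in (True, False):
        assert a.to('cpu', non_blocking=non_blocking).device == a.device
        assert a.to('cpu', dtype=torch.int32, non_blocking=non_blocking).dtype is torch.int32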

tools/autograd/templates/python_nn_functions.cpp

Lines changed: 30 additions & 0 deletions
@@ -2,9 +2,12 @@
 
 // ${generated_comment}
 
+#include "torch/csrc/Device.h"
+#include "torch/csrc/DynamicTypes.h"
 #include "torch/csrc/Exceptions.h"
 #include "torch/csrc/autograd/python_variable.h"
 #include "torch/csrc/autograd/utils/wrap_outputs.h"
+#include "torch/csrc/autograd/utils/python_arg_parsing.h"
 #include "torch/csrc/utils/python_arg_parser.h"
 
 #include "python_nn_functions_dispatch.h"

@@ -15,9 +18,36 @@ using namespace torch::autograd::utils;
 
 namespace torch { namespace autograd {
 
+static PyObject * THPVariable__parse_to(PyObject* module, PyObject* args, PyObject* kwargs)
+{
+  HANDLE_TH_ERRORS
+  auto parsed = parse_to_conversion(args, kwargs);
+  auto& device = std::get<0>(parsed);
+  auto& scalarType = std::get<1>(parsed);
+  auto non_blocking = std::get<2>(parsed);
+  auto tuple = THPObjectPtr{PyTuple_New(3)};
+  if (!tuple) throw python_error();
+  if (device) {
+    PyTuple_SET_ITEM(tuple.get(), 0, THPDevice_New(*device));
+  } else {
+    Py_INCREF(Py_None);
+    PyTuple_SET_ITEM(tuple.get(), 0, Py_None);
+  }
+  if (scalarType) {
+    PyTuple_SET_ITEM(tuple.get(), 1, torch::autograd::utils::wrap(torch::getDtype(*scalarType)));
+  } else {
+    Py_INCREF(Py_None);
+    PyTuple_SET_ITEM(tuple.get(), 1, Py_None);
+  }
+  PyTuple_SET_ITEM(tuple.get(), 2, torch::autograd::utils::wrap(non_blocking));
+  return tuple.release();
+  END_HANDLE_TH_ERRORS
+}
+
 ${py_methods}
 
 static PyMethodDef nn_functions[] = {
+  {"_parse_to", (PyCFunction)THPVariable__parse_to, METH_VARARGS | METH_KEYWORDS, nullptr},
   ${py_method_defs}
   {NULL}
 };
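
Because nn_functions backs the torch._C._nn module, _parse_to becomes callable from Python, which is the "common parse function with Tensor.to" the commit message refers to. A hedged sketch of the normalized (device, dtype, non_blocking) triple it returns as of this commit:

    import torch

    _parse_to = torch._C._nn._parse_to

    print(_parse_to('cpu', torch.float64, True))  # (device(type='cpu'), torch.float64, True)
    print(_parse_to(torch.int32))                 # (None, torch.int32, False)
    print(_parse_to(torch.zeros(1)))              # device and dtype taken from the tensor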

tools/autograd/templates/python_variable_methods.cpp

Lines changed: 16 additions & 24 deletions
@@ -8,6 +8,7 @@
 #include "torch/csrc/autograd/python_variable.h"
 #include "torch/csrc/autograd/utils/python_error_messages.h"
 #include "torch/csrc/autograd/utils/wrap_outputs.h"
+#include "torch/csrc/autograd/utils/python_arg_parsing.h"
 #include "torch/csrc/jit/tracer.h"
 #ifdef WITH_CUDA
 #include "torch/csrc/cuda/Stream.h"

@@ -558,31 +559,22 @@ static PyObject * THPVariable_storage_type(PyObject* self, PyObject* arg)
 static PyObject * THPVariable_to(PyObject* self, PyObject* args, PyObject* kwargs)
 {
   HANDLE_TH_ERRORS
-  static PythonArgParser parser({
-    "to(Device device, ScalarType dtype=None)",
-    "to(ScalarType dtype)",
-    "to(Tensor other)",
-  });
-  auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
-  ParsedArgs<2> parsed_args;
-  auto r = parser.parse(args, kwargs, parsed_args);
-  if (r.idx == 0) {
-    auto device = r.device(0);
-    auto deviceAutoGPU = device.deviceInt64();
-    auto scalarType = r.scalartypeWithDefault(1, self_.type().scalarType());
-    auto& layout = *torch::getLayout(self_.type().backend());
-    auto& type = torch::getType(scalarType, layout, device.type);
-    return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type, deviceAutoGPU, false));
-  } else if (r.idx == 1) {
-    auto scalarType = r.scalartype(0);
-    auto& type = self_.type().toScalarType(scalarType);
+  auto parsed = parse_to_conversion(args, kwargs);
+  auto& device = std::get<0>(parsed);
+  auto& scalarType = std::get<1>(parsed);
+  auto non_blocking = std::get<2>(parsed);
+  if (!device) {
+    // device not given
+    auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
+    auto& type = self_.type().toScalarType(scalarType.value_or(self_.type().scalarType()));
     return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type));
-  } else if (r.idx == 2) {
-    auto other = r.tensor(0);
-    auto& type = other.type();
-    auto deviceType = torch::getDeviceType(type);
-    auto deviceAutoGPU = (deviceType == DeviceType::CPU) ? -1 : other.get_device();
-    return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type, deviceAutoGPU, false));
+  } else {
+    // device and maybe dtype are given
+    auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
+    auto deviceAutoGPU = device->deviceInt64();
+    auto& layout = *torch::getLayout(self_.type().backend());
+    auto& type = torch::getType(scalarType.value_or(self_.type().scalarType()), layout, device->type);
+    return THPVariable_Wrap(torch::utils::dispatch_type_conversion(self_, type, deviceAutoGPU, non_blocking));
   }
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
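
The two branches above are observable from Python: with no device the tensor stays put and only its scalar type changes; with a device the placement and type are resolved together and non_blocking is forwarded to the conversion. A short sketch:

    import torch

    t = torch.randn(2, 2)

    # No device given: dtype changes, device is untouched (non_blocking is parsed but moot).
    assert t.to(torch.float64).device == t.device

    # Device given, dtype optional: both are resolved, non_blocking is honored.
    u = t.to('cpu', torch.float64, non_blocking=True)
    assert u.dtype is torch.float64 and u.device.type == 'cpu'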

torch/_tensor_docs.py

Lines changed: 10 additions & 5 deletions
@@ -2003,15 +2003,20 @@ def callable(a, b) -> number
 
 Returns a Tensor with the specified :attr:`dtype`
 
-.. function:: to(device, dtype=None) -> Tensor
+.. function:: to(device=None, dtype=None, non_blocking=False) -> Tensor
 
 Returns a Tensor with the specified :attr:`device` and (optional)
 :attr:`dtype`. If :attr:`dtype` is ``None`` it is inferred to be ``self.dtype``.
+When :attr:`non_blocking`, tries to convert asynchronously with respect to
+the host if possible, e.g., converting a CPU Tensor with pinned memory to a
+CUDA Tensor.
 
-.. function:: to(other) -> Tensor
+.. function:: to(other, non_blocking=False) -> Tensor
 
-Returns a Tensor with same :class:`torch.dtype` and :class:`torch.device` as the Tensor
-:attr:`other`.
+Returns a Tensor with same :class:`torch.dtype` and :class:`torch.device` as
+the Tensor :attr:`other`. When :attr:`non_blocking`, tries to convert
+asynchronously with respect to the host if possible, e.g., converting a CPU
+Tensor with pinned memory to a CUDA Tensor.
 
 Example::

@@ -2030,7 +2035,7 @@ def callable(a, b) -> number
            [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0')
 
     >>> other = torch.randn((), dtype=torch.float64, device=cuda0)
-    >>> tensor.to(other)
+    >>> tensor.to(other, non_blocking=True)
     tensor([[-0.5044,  0.0005],
            [ 0.3310, -0.0584]], dtype=torch.float64, device='cuda:0')
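
A hedged sketch of the pinned-memory case the documentation singles out: only page-locked host memory can be copied to the GPU asynchronously, and the host should synchronize before timing or inspecting the result.

    import torch

    if torch.cuda.is_available():
        src = torch.randn(1024).pin_memory()     # page-locked host memory
        dst = src.to('cuda', non_blocking=True)  # copy may overlap host-side work
        torch.cuda.synchronize()                 # ensure the copy has completed
        assert dst.device.type == 'cuda'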

torch/csrc/autograd/python_variable.cpp

Lines changed: 7 additions & 15 deletions
@@ -366,32 +366,24 @@ PyObject *THPVariable_is_sparse(THPVariable *self)
   END_HANDLE_TH_ERRORS
 }
 
-PyObject *THPVariable_dtype(THPVariable *self)
+static PyObject *THPVariable_dtype(THPVariable *self)
 {
   HANDLE_TH_ERRORS
   auto& self_ = self->cdata;
   return torch::autograd::utils::wrap(torch::getDtype(self_.type().scalarType()));
   END_HANDLE_TH_ERRORS
 }
 
-static PyObject * THPVariable_layout(THPVariable* self, PyObject* args) {
+static PyObject * THPVariable_layout(THPVariable* self) {
   HANDLE_TH_ERRORS
   auto& self_ = self->cdata;
   return torch::autograd::utils::wrap(torch::getLayout(self_.type().backend()));
   END_HANDLE_TH_ERRORS
 }
 
-static PyObject * THPVariable_device(THPVariable* self, PyObject* args) {
+static PyObject * THPVariable_device(THPVariable* self) {
   HANDLE_TH_ERRORS
-  auto& self_ = self->cdata;
-  if (self_.type().is_cuda()) {
-    torch::Device device(torch::DeviceType::CUDA, self_.get_device(), false);
-    return THPDevice_New(device);
-  }
-  else {
-    torch::Device device(torch::DeviceType::CPU, -1, true);
-    return THPDevice_New(device);
-  }
+  return THPDevice_New(torch::tensor::getDevice(self->cdata));
   END_HANDLE_TH_ERRORS
 }

@@ -413,9 +405,9 @@ static struct PyGetSetDef THPVariable_properties[] = {
   {"shape", (getter)THPVariable_get_shape, nullptr, nullptr, nullptr},
   {"is_cuda", (getter)THPVariable_is_cuda, nullptr, nullptr, nullptr},
   {"is_sparse", (getter)THPVariable_is_sparse, nullptr, nullptr, nullptr},
-  {"dtype", (getter)THPVariable_dtype, NULL, NULL, NULL},
-  {"layout", (getter)THPVariable_layout, NULL, NULL, NULL},
-  {"device", (getter)THPVariable_device, NULL, NULL, NULL},
+  {"dtype", (getter)THPVariable_dtype, nullptr, nullptr, nullptr},
+  {"layout", (getter)THPVariable_layout, nullptr, nullptr, nullptr},
+  {"device", (getter)THPVariable_device, nullptr, nullptr, nullptr},
   {nullptr}
 };
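
Since each PyGetSetDef entry registers a getter but no setter, the three attributes are read-only from Python; a quick sketch:

    import torch

    t = torch.zeros(2)
    print(t.dtype, t.layout, t.device)   # torch.float32 torch.strided cpu

    try:
        t.dtype = torch.float64          # no setter is registered
    except AttributeError:
        print('dtype is read-only')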

torch/csrc/autograd/utils/python_arg_parsing.h

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+#pragma once
+
+#include "torch/csrc/python_headers.h"
+#include <ATen/ATen.h>
+
+#include "torch/csrc/utils/python_arg_parser.h"
+#include "torch/csrc/utils/device.h"
+
+namespace torch { namespace autograd { namespace utils {
+
+inline std::tuple<at::optional<torch::Device>, at::optional<at::ScalarType>, bool>
+parse_to_conversion(PyObject *args, PyObject *kwargs) {
+  static PythonArgParser parser({
+    "to(Device device=None, ScalarType dtype=None, bool non_blocking=False)",
+    "to(ScalarType dtype, bool non_blocking=False)",
+    "to(Tensor tensor, bool non_blocking=False)",
+  });
+  ParsedArgs<3> parsed_args;
+  auto r = parser.parse(args, kwargs, parsed_args);
+  if (r.idx == 0) {
+    return std::make_tuple(r.deviceOptional(0), r.scalartypeOptional(1), r.toBool(2));
+  } else if (r.idx == 1) {
+    return std::make_tuple(at::nullopt, r.scalartype(0), r.toBool(1));
+  } else {
+    auto tensor = r.tensor(0);
+    return std::make_tuple(
+      torch::tensor::getDevice(tensor),
+      tensor.type().scalarType(),
+      r.toBool(1)
+    );
+  }
+}
+
+}}} // namespace torch::autograd::utils

torch/csrc/tensor/python_tensor.cpp

Lines changed: 8 additions & 0 deletions
@@ -384,4 +384,12 @@ at::Type& get_default_tensor_type() {
   return *default_tensor_type;
 }
 
+Device getDevice(const at::Tensor& tensor) {
+  if (tensor.type().is_cuda()) {
+    return torch::Device(torch::DeviceType::CUDA, tensor.get_device(), false);
+  } else {
+    return torch::Device(torch::DeviceType::CPU, -1, true);
+  }
+}
+
 }} // namespace torch::tensor
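
A sketch of what getDevice reports, seen through the Python device attribute it now backs: CPU tensors map to the default CPU device (index -1, is_default=true above), while CUDA tensors keep their concrete index.

    import torch

    print(torch.zeros(1).device)  # cpu

    if torch.cuda.is_available():
        print(torch.zeros(1, device='cuda:0').device)  # cuda:0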

torch/csrc/tensor/python_tensor.h

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "torch/csrc/python_headers.h"
+#include "torch/csrc/utils/device.h"
 #include <ATen/ATen.h>
 
 namespace torch { namespace tensor {

@@ -23,4 +24,7 @@ void py_set_default_dtype(PyObject* dtype_obj);
 // returned value will be a VariableType instance.
 at::Type& get_default_tensor_type();
 
+// Gets the torch::Device object of a given at::Tensor
+Device getDevice(const at::Tensor& tensor);
+
 }} // namespace torch::tensor
