Skip to content

Commit 361648a

Browse files
authored
Fix torch.tensor(...) device-type calculation when used with numpy an… (#6995)
* Fix torch.tensor(...) device-type calculation when used with numpy and type inference. * Fix tensor device type inference as well. * Better variable type inference: infer cuda-ness only if device is not specified.
1 parent 0c737df commit 361648a

File tree

6 files changed

+61
-24
lines changed

6 files changed

+61
-24
lines changed

test/test_sparse.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,6 +891,19 @@ def test_factory_type_inference(self):
891891
t = torch.sparse_coo_tensor(torch.tensor(([0], [2])), torch.tensor([1]))
892892
self.assertEqual(torch.int64, t.dtype)
893893

894+
@cuda_only
895+
def test_factory_device_type_inference(self):
896+
# both indices/values are CUDA
897+
shape = (1, 3)
898+
for indices_device in ['cuda', 'cpu']:
899+
for values_device in ['cuda', 'cpu']:
900+
for sparse_device in ['cuda', 'cpu', None]:
901+
t = torch.sparse_coo_tensor(torch.tensor(([0], [2]), device=indices_device),
902+
torch.tensor([1.], device=values_device),
903+
(1, 3), device=sparse_device)
904+
should_be_cuda = sparse_device == 'cuda' or (sparse_device is None and values_device == 'cuda')
905+
self.assertEqual(should_be_cuda, t.is_cuda)
906+
894907
@cpu_only
895908
def test_factory_copy(self):
896909
# both correct

test/test_torch.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,6 +1745,10 @@ def assertEqual(device_str, fn):
17451745
# NOTE: 'cpu' is the canonical representation of 'cpu:0', but 'cuda:X' is the canonical
17461746
# representation of cuda devices.
17471747
assertEqual('cpu', lambda: torch.ones((2, 3), dtype=torch.float32, device='cpu:0'))
1748+
assertEqual('cpu', lambda: torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cpu:0'))
1749+
if TEST_NUMPY:
1750+
assertEqual('cpu', lambda: torch.tensor(np.random.randn(2, 3), device='cpu'))
1751+
17481752
if torch.cuda.is_available():
17491753
assertEqual('cuda:0', lambda: torch.tensor(5).cuda(0))
17501754
assertEqual('cuda:0', lambda: torch.tensor(5).cuda('cuda:0'))
@@ -1754,12 +1758,18 @@ def assertEqual(device_str, fn):
17541758
assertEqual('cuda:0', lambda: torch.tensor(5, dtype=torch.int64, device='cuda:0'))
17551759
assertEqual('cuda:' + str(torch.cuda.current_device()),
17561760
lambda: torch.tensor(5, dtype=torch.int64, device='cuda'))
1761+
assertEqual('cuda:0', lambda: torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cuda:0'))
1762+
if TEST_NUMPY:
1763+
assertEqual('cuda:0', lambda: torch.tensor(np.random.randn(2, 3), device='cuda:0'))
17571764

17581765
if torch.cuda.device_count() > 1:
17591766
assertEqual('cuda:1', lambda: torch.tensor(5).cuda(1))
17601767
assertEqual('cuda:1', lambda: torch.tensor(5).cuda('cuda:1'))
17611768
assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.int64, device=1))
17621769
assertEqual('cuda:1', lambda: torch.tensor(5, dtype=torch.int64, device='cuda:1'))
1770+
assertEqual('cuda:1', lambda: torch.tensor(torch.ones((2, 3), dtype=torch.float32), device='cuda:1'))
1771+
if TEST_NUMPY:
1772+
assertEqual('cuda:1', lambda: torch.tensor(np.random.randn(2, 3), device='cuda:1'))
17631773

17641774
def test_to(self):
17651775
a = torch.tensor(5)

torch/csrc/autograd/python_variable_indexing.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ static Variable applySelect(const Variable& self, int64_t dim, int64_t index) {
102102

103103
static Variable sequenceToVariable(const Type& type, PyObject* seq) {
104104
auto& idx_type = type.toScalarType(kLong);
105-
return torch::utils::legacy_new_from_data(idx_type, -1, seq);
105+
return torch::utils::legacy_new_from_data(idx_type, at::nullopt, seq);
106106
}
107107

108108
static Variable valueToTensor(const Type & type, PyObject* value) {

torch/csrc/utils/python_arg_parser.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ struct PythonArgs {
101101
inline Device device(int i);
102102
inline Device deviceWithDefault(int i, const Device& default_device);
103103
inline int64_t deviceInt64(int i);
104+
inline at::optional<Device> deviceOptional(int i);
104105
inline std::string string(int i);
105106
inline PyObject* pyobject(int i);
106107
inline int64_t toInt64(int i);
@@ -332,6 +333,11 @@ inline int64_t PythonArgs::deviceInt64(int i) {
332333
return dev.deviceInt64();
333334
}
334335

336+
inline at::optional<Device> PythonArgs::deviceOptional(int i) {
337+
if (!args[i]) return at::nullopt;
338+
return device(i);
339+
}
340+
335341
inline std::string PythonArgs::string(int i) {
336342
if (!args[i]) return "";
337343
return THPUtils_unpackString(args[i]);

torch/csrc/utils/tensor_new.cpp

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -171,24 +171,32 @@ static void recursive_store(char* data, IntList sizes, IntList strides, int64_t
171171
}
172172
}
173173

174-
static Tensor internal_new_from_data(const Type & type, int device, PyObject* data,
174+
static Tensor internal_new_from_data(const Type & type, at::optional<Device> device_opt, PyObject* data,
175175
bool copy_variables, bool copy_numpy,
176176
bool type_inference) {
177+
int64_t device = device_opt.has_value() ? device_opt.value().deviceInt64() : -1;
177178
if (THPUtils_checkString(data)) {
178179
throw TypeError("new(): invalid data type '%s'", Py_TYPE(data)->tp_name);
179180
}
180181

181182
if (THPVariable_Check(data)) {
182183
auto var = reinterpret_cast<THPVariable*>(data)->cdata;
183-
const auto& type_to_use = type_inference ? var.type() : type;
184+
auto type_inference_device_type = device_opt.has_value() ? device_opt.value().type
185+
: torch::getDeviceType(var.type());
186+
// infer the scalar type and device type; it's not expected to infer the layout since these constructors
187+
// are defined per-layout-type (e.g. tensor vs sparse_coo_tensor).
188+
const auto& type_inference_type = torch::getType(var.type().scalarType(),
189+
*torch::getLayout(type.backend()),
190+
type_inference_device_type);
191+
const auto& type_to_use = type_inference ? type_inference_type : type;
184192
return copy_variables ? new_with_tensor_copy(type_to_use, var, device) :
185193
new_with_type_conversion(type_to_use, var, device);
186194
}
187195

188196
#ifdef WITH_NUMPY
189197
if (PyArray_Check(data)) {
190198
auto tensor = autograd::make_variable(tensor_from_numpy(data), /*requires_grad=*/false);
191-
const auto& type_to_use = type_inference ? tensor.type() : type;
199+
const auto& type_to_use = type_inference ? type.toScalarType(tensor.type().scalarType()) : type;
192200
return copy_numpy ? new_with_tensor_copy(type_to_use, tensor, device) :
193201
new_with_type_conversion(type_to_use, tensor, device);
194202
}
@@ -204,15 +212,15 @@ static Tensor internal_new_from_data(const Type & type, int device, PyObject* da
204212
return new_with_type_conversion(type_to_use, tensor, device);
205213
}
206214

207-
Tensor legacy_new_from_data(const Type & type, int device, PyObject *data) {
215+
Tensor legacy_new_from_data(const Type & type, at::optional<Device> device, PyObject *data) {
208216
return internal_new_from_data(type, device, data, false, false, false);
209217
}
210218

211-
static Tensor new_from_data_copy(const Type & type, int device, PyObject *data) {
219+
static Tensor new_from_data_copy(const Type & type, at::optional<Device> device, PyObject *data) {
212220
return internal_new_from_data(type, device, data, true, true, false);
213221
}
214222

215-
static Tensor legacy_new_from_sequence(const Type & type, int device, PyObject* data) {
223+
static Tensor legacy_new_from_sequence(const Type & type, at::optional<Device> device, PyObject* data) {
216224
if (!PySequence_Check(data)) {
217225
throw TypeError("new(): data must be a sequence (got %s)", Py_TYPE(data)->tp_name);
218226
}
@@ -246,7 +254,7 @@ static Tensor legacy_sparse_tensor_ctor(const Type& type, PyObject* args, PyObje
246254
if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
247255
// new(sequence) binds to this signature but should be treated differently
248256
// unless the sequences is a torch.Size
249-
return legacy_new_from_sequence(type, r.deviceInt64(1), r.pyobject(0));
257+
return legacy_new_from_sequence(type, r.deviceOptional(1), r.pyobject(0));
250258
}
251259
return new_with_sizes(type, r.deviceInt64(1), r.intlist(0));
252260
}
@@ -284,11 +292,11 @@ Tensor legacy_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
284292
if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
285293
// new(sequence) binds to this signature but should be treated differently
286294
// unless the sequences is a torch.Size
287-
return legacy_new_from_sequence(type, r.deviceInt64(1), r.pyobject(0));
295+
return legacy_new_from_sequence(type, r.deviceOptional(1), r.pyobject(0));
288296
}
289297
return new_with_sizes(type, r.deviceInt64(1), r.intlist(0));
290298
} else if (r.idx == 5) {
291-
return legacy_new_from_sequence(type, r.deviceInt64(1), r.pyobject(0));
299+
return legacy_new_from_sequence(type, r.deviceOptional(1), r.pyobject(0));
292300
}
293301
throw std::runtime_error("new(): invalid arguments");
294302
}
@@ -324,7 +332,7 @@ static Tensor legacy_sparse_tensor_new(const Type& type, PyObject* args, PyObjec
324332
if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
325333
// new(sequence) binds to this signature but should be treated differently
326334
// unless the sequences is a torch.Size
327-
return legacy_new_from_sequence(type, r.deviceInt64(1), r.pyobject(0));
335+
return legacy_new_from_sequence(type, r.deviceOptional(1), r.pyobject(0));
328336
}
329337
return new_with_sizes(type, r.deviceInt64(1), r.intlist(0));
330338
}
@@ -362,11 +370,11 @@ Tensor legacy_tensor_new(const Type& type, PyObject* args, PyObject* kwargs) {
362370
if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) {
363371
// new(sequence) binds to this signature but should be treated differently
364372
// unless the sequences is a torch.Size
365-
return legacy_new_from_sequence(type, r.deviceInt64(1), r.pyobject(0));
373+
return legacy_new_from_sequence(type, r.deviceOptional(1), r.pyobject(0));
366374
}
367375
return new_with_sizes(type, r.deviceInt64(1), r.intlist(0));
368376
} else if (r.idx == 5) {
369-
return legacy_new_from_sequence(type, r.deviceInt64(1), r.pyobject(0));
377+
return legacy_new_from_sequence(type, r.deviceOptional(1), r.pyobject(0));
370378
}
371379
throw std::runtime_error("new(): invalid arguments");
372380
}
@@ -398,22 +406,21 @@ Tensor sparse_coo_tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs
398406
bool type_inference = r.isNone(2);
399407
const auto& sparse_type = typeWithDefault(r, 2, 3, default_sparse_type);
400408
const auto& dense_type = sparse_type.toBackend(sparse_type.is_cuda() ? kCUDA : kCPU);
401-
const auto& index_type = dense_type.toScalarType(kLong);
402409
AutoGPU autogpu(r.deviceInt64(3));
403-
// explanation of booleans: allow variables, do type conversion of them, copy numpy data
404-
Tensor indices = internal_new_from_data(index_type, -1, r.pyobject(0), false, true, false);
405-
Tensor values = internal_new_from_data(dense_type, -1, r.pyobject(1), false, true, type_inference);
410+
Tensor values = internal_new_from_data(dense_type, r.deviceOptional(3), r.pyobject(1), false, true, type_inference);
411+
// if no dtype provided, infer type based on value type.
412+
const auto& index_type = values.type().toScalarType(kLong);
413+
Tensor indices = internal_new_from_data(index_type, r.deviceOptional(3), r.pyobject(0), false, true, false);
406414
const auto& sparse_type_to_use = values.type().toBackend(values.type().is_cuda() ? kSparseCUDA : kSparseCPU);
407415
return set_requires_grad(sparse_type_to_use.sparse_coo_tensor(indices, values), r.toBool(4));
408416
} else if (r.idx == 1) {
409417
bool type_inference = r.isNone(3);
410418
const auto& sparse_type = typeWithDefault(r, 3, 4, default_sparse_type);
411419
const auto& dense_type = sparse_type.toBackend(sparse_type.is_cuda() ? kCUDA : kCPU);
412-
const auto& index_type = dense_type.toScalarType(kLong);
413420
AutoGPU autogpu(r.deviceInt64(4));
414-
// explanation of booleans: allow variables, do type conversion of them, copy numpy data
415-
Tensor indices = internal_new_from_data(index_type, -1, r.pyobject(0), false, true, false);
416-
Tensor values = internal_new_from_data(dense_type, -1, r.pyobject(1), false, true, type_inference);
421+
Tensor values = internal_new_from_data(dense_type, r.deviceOptional(4), r.pyobject(1), false, true, type_inference);
422+
const auto& index_type = values.type().toScalarType(kLong);
423+
Tensor indices = internal_new_from_data(index_type, r.deviceOptional(4), r.pyobject(0), false, true, false);
417424
const auto& sparse_type_to_use = values.type().toBackend(values.type().is_cuda() ? kSparseCUDA : kSparseCPU);
418425
return set_requires_grad(sparse_type_to_use.sparse_coo_tensor(indices, values, r.intlist(2)), r.toBool(5));
419426
}
@@ -430,7 +437,7 @@ Tensor tensor_ctor(const Type& type, PyObject* args, PyObject* kwargs) {
430437
if (r.idx == 0) {
431438
bool type_inference = r.isNone(1);
432439
return set_requires_grad(internal_new_from_data(
433-
typeWithDefault(r, 1, 2, type), r.deviceInt64(2), r.pyobject(0), true, true, type_inference), r.toBool(3));
440+
typeWithDefault(r, 1, 2, type), r.deviceOptional(2), r.pyobject(0), true, true, type_inference), r.toBool(3));
434441
}
435442
throw std::runtime_error("tensor(): invalid arguments");
436443
}
@@ -445,7 +452,7 @@ Tensor new_tensor(const Type& type, PyObject* args, PyObject* kwargs) {
445452
auto r = parser.parse(args, kwargs, parsed_args);
446453
if (r.idx == 0) {
447454
return set_requires_grad(new_from_data_copy(
448-
typeWithDefault(r, 1, 2, type), r.deviceInt64(2), r.pyobject(0)), r.toBool(3));
455+
typeWithDefault(r, 1, 2, type), r.deviceOptional(2), r.pyobject(0)), r.toBool(3));
449456
}
450457
throw std::runtime_error("new_tensor(): invalid arguments");
451458
}

torch/csrc/utils/tensor_new.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
#pragma once
22

33
#include "torch/csrc/python_headers.h"
4+
#include "torch/csrc/utils/device.h"
45
#include <ATen/ATen.h>
56

67
namespace torch { namespace utils {
78

89
at::Tensor legacy_tensor_ctor(const at::Type& type, PyObject* args, PyObject* kwargs);
910
at::Tensor legacy_tensor_new(const at::Type& type, PyObject* args, PyObject* kwargs);
10-
at::Tensor legacy_new_from_data(const at::Type& type, int device, PyObject *data);
11+
at::Tensor legacy_new_from_data(const at::Type& type, at::optional<Device> device, PyObject *data);
1112
at::Tensor sparse_coo_tensor_ctor(const at::Type& type, PyObject* args, PyObject* kwargs);
1213
at::Tensor tensor_ctor(const at::Type& type, PyObject* args, PyObject* kwargs);
1314
at::Tensor new_tensor(const at::Type& type, PyObject* args, PyObject* kwargs);

0 commit comments

Comments
 (0)