Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,26 +1245,32 @@ def test_constructor_dtypes(self):
default_type = torch.Tensor().type()
self.assertIs(torch.Tensor().dtype, torch.Tensor.dtype)

torch.set_default_tensor_type('torch.IntTensor')
self.assertIs(torch.int32, torch.Tensor.dtype)
self.assertIs(torch.int32, torch.IntTensor.dtype)
self.assertEqual(torch.IntStorage, torch.Storage)
torch.set_default_tensor_type('torch.FloatTensor')
self.assertIs(torch.float32, torch.Tensor.dtype)
self.assertIs(torch.float32, torch.FloatTensor.dtype)
self.assertEqual(torch.FloatStorage, torch.Storage)

torch.set_default_tensor_type(torch.int64)
self.assertIs(torch.int64, torch.Tensor.dtype)
self.assertIs(torch.int64, torch.LongTensor.dtype)
self.assertEqual(torch.LongStorage, torch.Storage)
torch.set_default_tensor_type(torch.float64)
self.assertIs(torch.float64, torch.Tensor.dtype)
self.assertIs(torch.float64, torch.DoubleTensor.dtype)
self.assertEqual(torch.DoubleStorage, torch.Storage)

torch.set_default_tensor_type('torch.Tensor')
self.assertIs(torch.int64, torch.Tensor.dtype)
self.assertIs(torch.int64, torch.LongTensor.dtype)
self.assertEqual(torch.LongStorage, torch.Storage)
self.assertIs(torch.float64, torch.Tensor.dtype)
self.assertIs(torch.float64, torch.DoubleTensor.dtype)
self.assertEqual(torch.DoubleStorage, torch.Storage)

if torch.cuda.is_available():
torch.set_default_tensor_type(torch.cuda.float64)
self.assertIs(torch.cuda.float64, torch.Tensor.dtype)
self.assertIs(torch.cuda.float64, torch.cuda.DoubleTensor.dtype)
self.assertEqual(torch.cuda.DoubleStorage, torch.Storage)
torch.set_default_tensor_type(torch.cuda.float32)
self.assertIs(torch.cuda.float32, torch.Tensor.dtype)
self.assertIs(torch.cuda.float32, torch.cuda.FloatTensor.dtype)
self.assertEqual(torch.cuda.FloatStorage, torch.Storage)

# don't support integral or sparse default types.
self.assertRaises(TypeError, lambda: torch.set_default_tensor_type('torch.IntTensor'))
self.assertRaises(TypeError, lambda: torch.set_default_tensor_type(torch.int64))
self.assertRaises(TypeError, lambda: torch.set_default_tensor_type(torch.sparse.int64))
self.assertRaises(TypeError, lambda: torch.set_default_tensor_type(torch.sparse.float64))

torch.set_default_tensor_type(default_type)

Expand Down
8 changes: 8 additions & 0 deletions torch/csrc/tensor/python_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,14 @@ void py_set_default_tensor_type(PyObject* obj) {
throw unavailable_type(*type);
}

if (!at::isFloatingType(type->aten_type->scalarType())) {
throw TypeError("only floating-point types are supported as the default type");
}

if (type->aten_type->is_sparse()) {
throw TypeError("only dense types are supported as the default type");
}

// get the storage first, so if it doesn't exist we don't change the default tensor type
THPObjectPtr storage = get_storage_obj(*type);
set_default_tensor_type(*type->aten_type);
Expand Down