Skip to content

Commit 9b31280

Browse files
lantigaezyang
authored andcommitted
Have __sizeof__ account for size of stored elements (#3821)
* Have __sizeof__ account for size of stored elements * Conform to sizeof specification
1 parent 558516f commit 9b31280

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

test/test_torch.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4349,6 +4349,19 @@ def test_print(self):
43494349
x.__repr__()
43504350
str(x),
43514351

4352+
def test_sizeof(self):
4353+
sizeof_empty = torch.randn(0).storage().__sizeof__()
4354+
sizeof_10 = torch.randn(10).storage().__sizeof__()
4355+
sizeof_100 = torch.randn(100).storage().__sizeof__()
4356+
self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10)
4357+
self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0)
4358+
4359+
sizeof_empty = torch.randn(0).type(torch.ByteTensor).storage().__sizeof__()
4360+
sizeof_10 = torch.randn(10).type(torch.ByteTensor).storage().__sizeof__()
4361+
sizeof_100 = torch.randn(100).type(torch.ByteTensor).storage().__sizeof__()
4362+
self.assertEqual((sizeof_100 - sizeof_empty) // (sizeof_10 - sizeof_empty), 10)
4363+
self.assertEqual((sizeof_100 - sizeof_empty) % (sizeof_10 - sizeof_empty), 0)
4364+
43524365
def test_unsqueeze(self):
43534366
x = torch.randn(2, 3, 4)
43544367
y = x.unsqueeze(1)

torch/storage.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ def __deepcopy__(self, memo):
3030
def __reduce__(self):
3131
return type(self), (self.tolist(),)
3232

33+
def __sizeof__(self):
34+
return super(_StorageBase, self).__sizeof__() + self.element_size() * self.size()
35+
3336
def clone(self):
3437
"""Returns a copy of this storage"""
3538
return type(self)(self.size()).copy_(self)

0 commit comments

Comments
 (0)