Skip to content

Commit 8091388

Browse files
authored
Add support for __floordiv__ and __rdiv__ for integral tensors (#7245)
1 parent 371cc1e commit 8091388

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

test/test_torch.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,28 @@ def test_div(self):
10031003
res2[i, 3] = res2[i, 3] / 2
10041004
self.assertEqual(res1, res2)
10051005

1006+
def test_floordiv(self):
1007+
for dtype in torch.testing.get_all_dtypes():
1008+
if dtype is torch.float16:
1009+
continue
1010+
x = torch.randn(100).mul(10).to(dtype)
1011+
y = x // 3
1012+
self.assertEqual(y.dtype, x.dtype)
1013+
z = torch.tensor([math.trunc(v.item() / 3.) for v in x], dtype=y.dtype)
1014+
self.assertEqual(y, z)
1015+
1016+
def test_rdiv(self):
1017+
for dtype in torch.testing.get_all_dtypes():
1018+
if dtype is torch.float16:
1019+
continue
1020+
x = torch.rand(100).add(1).mul(4).to(dtype)
1021+
y = 30 / x
1022+
if dtype.is_floating_point:
1023+
z = torch.tensor([30 / v.item() for v in x], dtype=dtype)
1024+
else:
1025+
z = torch.tensor([math.trunc(30. / v.item()) for v in x], dtype=dtype)
1026+
self.assertEqual(y, z)
1027+
10061028
def test_fmod(self):
10071029
m1 = torch.Tensor(10, 10).uniform_(-10., 10.)
10081030
res1 = m1.clone()

torch/tensor.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,11 @@ def __rsub__(self, other):
317317
return -self + other
318318

319319
def __rdiv__(self, other):
320-
return self.reciprocal() * other
320+
if self.dtype.is_floating_point:
321+
return self.reciprocal() * other
322+
else:
323+
return (self.double().reciprocal() * other).type_as(self)
324+
321325
__rtruediv__ = __rdiv__
322326
__itruediv__ = _C._TensorBase.__idiv__
323327

@@ -334,6 +338,18 @@ def __ipow__(self, other):
334338
def __rpow__(self, other):
335339
return self.new([other]) ** self
336340

341+
def __floordiv__(self, other):
342+
result = self / other
343+
if result.dtype.is_floating_point:
344+
result = result.trunc()
345+
return result
346+
347+
def __rfloordiv__(self, other):
348+
result = other / self
349+
if result.dtype.is_floating_point:
350+
result = result.trunc()
351+
return result
352+
337353
__neg__ = _C._TensorBase.neg
338354

339355
__eq__ = _C._TensorBase.eq

0 commit comments

Comments
 (0)