Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,6 +1003,28 @@ def test_div(self):
res2[i, 3] = res2[i, 3] / 2
self.assertEqual(res1, res2)

def test_floordiv(self):
for dtype in torch.testing.get_all_dtypes():
if dtype is torch.float16:
continue
x = torch.randn(100).mul(10).to(dtype)
y = x // 3
self.assertEqual(y.dtype, x.dtype)
z = torch.tensor([math.trunc(v.item() / 3.) for v in x], dtype=y.dtype)
self.assertEqual(y, z)

def test_rdiv(self):
for dtype in torch.testing.get_all_dtypes():
if dtype is torch.float16:
continue
x = torch.rand(100).add(1).mul(4).to(dtype)
y = 30 / x
if dtype.is_floating_point:
z = torch.tensor([30 / v.item() for v in x], dtype=dtype)
else:
z = torch.tensor([math.trunc(30. / v.item()) for v in x], dtype=dtype)
self.assertEqual(y, z)

def test_fmod(self):
m1 = torch.Tensor(10, 10).uniform_(-10., 10.)
res1 = m1.clone()
Expand Down
18 changes: 17 additions & 1 deletion torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,11 @@ def __rsub__(self, other):
return -self + other

def __rdiv__(self, other):
return self.reciprocal() * other
if self.dtype.is_floating_point:
return self.reciprocal() * other
else:
return (self.double().reciprocal() * other).type_as(self)

This comment was marked as off-topic.

This comment was marked as off-topic.


__rtruediv__ = __rdiv__
__itruediv__ = _C._TensorBase.__idiv__

Expand All @@ -334,6 +338,18 @@ def __ipow__(self, other):
def __rpow__(self, other):
return self.new([other]) ** self

def __floordiv__(self, other):
result = self / other
if result.dtype.is_floating_point:
result = result.trunc()
return result

def __rfloordiv__(self, other):
result = other / self
if result.dtype.is_floating_point:
result = result.trunc()
return result

__neg__ = _C._TensorBase.neg

__eq__ = _C._TensorBase.eq
Expand Down