Skip to content

Commit 2b8998c

Browse files
committed
Add tests
1 parent da65cb2 commit 2b8998c

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

test/test_torch.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,28 @@ def test_div(self):
10031003
res2[i, 3] = res2[i, 3] / 2
10041004
self.assertEqual(res1, res2)
10051005

1006+
def test_floordiv(self):
1007+
for dtype in torch.testing.get_all_dtypes():
1008+
if dtype is torch.float16:
1009+
continue
1010+
x = torch.randn(100).mul(10).to(dtype)
1011+
y = x // 3
1012+
self.assertEqual(y.dtype, x.dtype)
1013+
z = torch.tensor([math.trunc(v.item() / 3) for v in x], dtype=y.dtype)
1014+
self.assertEqual(y, z)
1015+
1016+
def test_rdiv(self):
1017+
for dtype in torch.testing.get_all_dtypes():
1018+
if dtype is torch.float16:
1019+
continue
1020+
x = torch.rand(10).add(1).mul(4).to(dtype)
1021+
y = 30 / x
1022+
if dtype.is_floating_point:
1023+
z = torch.tensor([30 / v.item() for v in x], dtype=dtype)
1024+
else:
1025+
z = torch.tensor([math.trunc(30. / v.item()) for v in x], dtype=dtype)
1026+
self.assertEqual(y, z)
1027+
10061028
def test_fmod(self):
10071029
m1 = torch.Tensor(10, 10).uniform_(-10., 10.)
10081030
res1 = m1.clone()

0 commit comments

Comments
 (0)