File tree Expand file tree Collapse file tree 2 files changed +39
-1
lines changed
Expand file tree Collapse file tree 2 files changed +39
-1
lines changed Original file line number Diff line number Diff line change @@ -1003,6 +1003,28 @@ def test_div(self):
10031003 res2 [i , 3 ] = res2 [i , 3 ] / 2
10041004 self .assertEqual (res1 , res2 )
10051005
1006+ def test_floordiv (self ):
1007+ for dtype in torch .testing .get_all_dtypes ():
1008+ if dtype is torch .float16 :
1009+ continue
1010+ x = torch .randn (100 ).mul (10 ).to (dtype )
1011+ y = x // 3
1012+ self .assertEqual (y .dtype , x .dtype )
1013+ z = torch .tensor ([math .trunc (v .item () / 3. ) for v in x ], dtype = y .dtype )
1014+ self .assertEqual (y , z )
1015+
1016+ def test_rdiv (self ):
1017+ for dtype in torch .testing .get_all_dtypes ():
1018+ if dtype is torch .float16 :
1019+ continue
1020+ x = torch .rand (100 ).add (1 ).mul (4 ).to (dtype )
1021+ y = 30 / x
1022+ if dtype .is_floating_point :
1023+ z = torch .tensor ([30 / v .item () for v in x ], dtype = dtype )
1024+ else :
1025+ z = torch .tensor ([math .trunc (30. / v .item ()) for v in x ], dtype = dtype )
1026+ self .assertEqual (y , z )
1027+
10061028 def test_fmod (self ):
10071029 m1 = torch .Tensor (10 , 10 ).uniform_ (- 10. , 10. )
10081030 res1 = m1 .clone ()
Original file line number Diff line number Diff line change @@ -317,7 +317,11 @@ def __rsub__(self, other):
317317 return - self + other
318318
319319 def __rdiv__ (self , other ):
320- return self .reciprocal () * other
320+ if self .dtype .is_floating_point :
321+ return self .reciprocal () * other
322+ else :
323+ return (self .double ().reciprocal () * other ).type_as (self )
324+
321325 __rtruediv__ = __rdiv__
322326 __itruediv__ = _C ._TensorBase .__idiv__
323327
@@ -334,6 +338,18 @@ def __ipow__(self, other):
334338 def __rpow__ (self , other ):
335339 return self .new ([other ]) ** self
336340
341+ def __floordiv__ (self , other ):
342+ result = self / other
343+ if result .dtype .is_floating_point :
344+ result = result .trunc ()
345+ return result
346+
347+ def __rfloordiv__ (self , other ):
348+ result = other / self
349+ if result .dtype .is_floating_point :
350+ result = result .trunc ()
351+ return result
352+
337353 __neg__ = _C ._TensorBase .neg
338354
339355 __eq__ = _C ._TensorBase .eq
You can’t perform that action at this time.
0 commit comments