Skip to content

Commit da65cb2

Browse files
committed
Add support for __floordiv__ and __rdiv__ for integral tensors
1 parent 4ab6ea5 commit da65cb2

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

torch/tensor.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,11 @@ def __rsub__(self, other):
317317
return -self + other
318318

319319
def __rdiv__(self, other):
320-
return self.reciprocal() * other
320+
if self.dtype.is_floating_point:
321+
return self.reciprocal() * other
322+
else:
323+
return (self.double().reciprocal() * other).type_as(self)
324+
321325
__rtruediv__ = __rdiv__
322326
__itruediv__ = _C._TensorBase.__idiv__
323327

@@ -334,6 +338,18 @@ def __ipow__(self, other):
334338
def __rpow__(self, other):
335339
return self.new([other]) ** self
336340

341+
def __floordiv__(self, other):
342+
result = self / other
343+
if result.dtype.is_floating_point:
344+
result = result.trunc()
345+
return result
346+
347+
def __rfloordiv__(self, other):
348+
result = other / self
349+
if result.dtype.is_floating_point:
350+
result = result.trunc()
351+
return result
352+
337353
__neg__ = _C._TensorBase.neg
338354

339355
__eq__ = _C._TensorBase.eq

0 commit comments

Comments
 (0)