Skip to content

Commit c4572aa

Browse files
authored
[MPS] Add fixes for div with floor (#95869)
* [MPS] Add fixes for div with floor and raise error for div_trunc (#95769) Fixes #ISSUE_NUMBER Pull Request resolved: #95769 Approved by: https://github.com/DenisVieriu97 * Add back the unittest skip for MacOS 12.
1 parent 82b078b commit c4572aa

File tree

4 files changed

+20
-17
lines changed

4 files changed

+20
-17
lines changed

aten/src/ATen/native/mps/MPSGraphVenturaOps.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,7 @@ typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode)
138138
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode
139139
constantValue:(double) constantValue
140140
name:(NSString * _Nullable) name;
141+
- (MPSGraphTensor * _Nonnull) truncateWithTensor:(MPSGraphTensor * _Nonnull) tensor
142+
name:(NSString * _Nullable) name;
143+
141144
@end

aten/src/ATen/native/mps/operations/BinaryOps.mm

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,6 @@ void div_mode_template(const Tensor& self, const Tensor& other,
177177
c10::optional<c10::string_view> rounding_mode,
178178
const Tensor& output, const string op_name)
179179
{
180-
if(rounding_mode.has_value() && *rounding_mode == "floor"){
181-
TORCH_CHECK(self.scalar_type() != ScalarType::Long,
182-
"MPS: does not support floor_divide op with int64 input");
183-
}
184180
BinaryOpBlock div_mode_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
185181
MPSGraph* mpsGraph = cachedGraph->graph();
186182
bool isFloatInput = ([primaryCastTensor dataType] & MPSDataTypeFloatBit) != 0;

aten/src/ATen/native/mps/operations/UnaryOps.mm

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,20 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
7575
return inputTensor;
7676
}
7777

78-
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
79-
dataType:inputTensor.dataType];
80-
MPSGraphTensor* predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
81-
secondaryTensor:zeroTensor
82-
name:nil];
83-
return [mpsGraph selectWithPredicateTensor:predicateTensor
84-
truePredicateTensor:[mpsGraph ceilWithTensor :inputTensor name:nil]
85-
falsePredicateTensor:[mpsGraph floorWithTensor:inputTensor name:nil]
86-
name:nil];
78+
if(!is_macos_13_or_newer()) {
79+
MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0
80+
dataType:inputTensor.dataType];
81+
MPSGraphTensor* predicateTensor = [mpsGraph lessThanWithPrimaryTensor:inputTensor
82+
secondaryTensor:zeroTensor
83+
name:nil];
84+
return [mpsGraph selectWithPredicateTensor:predicateTensor
85+
truePredicateTensor:[mpsGraph ceilWithTensor :inputTensor name:nil]
86+
falsePredicateTensor:[mpsGraph floorWithTensor:inputTensor name:nil]
87+
name:nil];
88+
} else {
89+
return [mpsGraph truncateWithTensor:inputTensor
90+
name:nil];
91+
}
8792
};
8893

8994
} // namespace mps

test/test_mps.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2344,10 +2344,9 @@ def test_full_bugs(self):
23442344
# See https://github.com/pytorch/pytorch/issues/84995
23452345
def test_div_bugs(self):
23462346
for (dtype, mode) in itertools.product(integral_types(), ['trunc', 'floor']):
2347-
if dtype != torch.int64:
2348-
x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
2349-
y = torch.div(x, 101, rounding_mode=mode)
2350-
self.assertEqual(y.sum(), 0)
2347+
x = torch.tensor(list(range(1, 11)), device='mps', dtype=dtype)
2348+
y = torch.div(x, 101, rounding_mode=mode)
2349+
self.assertEqual(y.sum(), 0)
23512350

23522351
# See https://github.com/pytorch/pytorch/issues/82663
23532352
def test_bool_expand(self):

0 commit comments

Comments
 (0)