Skip to content

Commit 8dbe63c

Browse files
kulinsethpytorchmergebot
authored andcommitted
[MPS] Casting int64 to int32 for reduction ops and raise warning. (#94484)
Currently casting it as a workaround till we have full support in OS. Fixes ##88319 (comment) Pull Request resolved: #94484 Approved by: https://github.com/razarmehr
1 parent 715f373 commit 8dbe63c

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

aten/src/ATen/native/mps/operations/ReduceOps.mm

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,7 +1251,7 @@ Tensor std_mps(
12511251
(const Tensor& input_t,
12521252
MPSReductionType reduction_type,
12531253
const std::string& func_name) {
1254-
TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "MPS does not support min/max ops with int64 input");
1254+
TORCH_WARN_ONCE(input_t.scalar_type() != ScalarType::Long, "MPS: no support for int64 min/max ops, casting it to int32");
12551255

12561256
using CachedGraph = MPSUnaryCachedGraph;
12571257

@@ -1280,6 +1280,7 @@ Tensor std_mps(
12801280

12811281
MPSGraphTensor* outputTensor = nil;
12821282
MPSGraphTensor* castInputTensor = nil;
1283+
MPSGraphTensor* castOutputTensor = nil;
12831284

12841285
if (input_t.scalar_type() != ScalarType::Float &&
12851286
input_t.scalar_type() != ScalarType::Int &&
@@ -1302,8 +1303,15 @@ Tensor std_mps(
13021303
name:nil];
13031304
}
13041305

1306+
if(input_t.scalar_type() == ScalarType::Long) {
1307+
castOutputTensor = [mpsGraph castTensor:outputTensor
1308+
toType:MPSDataTypeInt64
1309+
name:@"castInputTensor"];
1310+
} else {
1311+
castOutputTensor = outputTensor;
1312+
}
13051313
newCachedGraph->inputTensor_ = inputTensor;
1306-
newCachedGraph->outputTensor_ = outputTensor;
1314+
newCachedGraph->outputTensor_ = castOutputTensor;
13071315
}
13081316
return newCachedGraph;
13091317
});

0 commit comments

Comments
 (0)