File tree Expand file tree Collapse file tree 1 file changed +10
-2
lines changed
aten/src/ATen/native/mps/operations Expand file tree Collapse file tree 1 file changed +10
-2
lines changed Original file line number Diff line number Diff line change @@ -1251,7 +1251,7 @@ Tensor std_mps(
12511251 (const Tensor& input_t ,
12521252 MPSReductionType reduction_type,
12531253 const std::string& func_name) {
1254- TORCH_CHECK (input_t .scalar_type () != ScalarType::Long, " MPS does not support min/max ops with int64 input " );
1254+ TORCH_WARN_ONCE (input_t .scalar_type () != ScalarType::Long, " MPS: no support for int64 min/max ops, casting it to int32 " );
12551255
12561256 using CachedGraph = MPSUnaryCachedGraph;
12571257
@@ -1280,6 +1280,7 @@ Tensor std_mps(
12801280
12811281 MPSGraphTensor* outputTensor = nil ;
12821282 MPSGraphTensor* castInputTensor = nil ;
1283+ MPSGraphTensor* castOutputTensor = nil ;
12831284
12841285 if (input_t .scalar_type () != ScalarType::Float &&
12851286 input_t .scalar_type () != ScalarType::Int &&
@@ -1302,8 +1303,15 @@ Tensor std_mps(
13021303 name: nil ];
13031304 }
13041305
1306+ if (input_t .scalar_type () == ScalarType::Long) {
1307+ castOutputTensor = [mpsGraph castTensor: outputTensor
1308+ toType: MPSDataTypeInt64
1309+ name: @" castInputTensor" ];
1310+ } else {
1311+ castOutputTensor = outputTensor;
1312+ }
13051313 newCachedGraph->inputTensor_ = inputTensor;
1306- newCachedGraph->outputTensor_ = outputTensor ;
1314+ newCachedGraph->outputTensor_ = castOutputTensor ;
13071315 }
13081316 return newCachedGraph;
13091317 });
You can’t perform that action at this time.
0 commit comments