Skip to content

Commit 4e4d4b0

Browse files
authored
[MPS] Add TORCH_CHECK for Convolution (#95495)
* Raise errors for Conv and remove FFTs from Fallback list. * Move the FFT to a separate commit.
1 parent c4fa850 commit 4e4d4b0

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

aten/src/ATen/native/mps/operations/Convolution.mm

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ Tensor _mps_convolution_impl(
6666
int64_t groups,
6767
c10::optional<IntArrayRef> input_shape) {
6868
TORCH_CHECK(input_t.dim() < 5, "Conv3D is not supported on MPS");
69+
TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types");
6970

7071
namespace native_mps = at::native::mps;
7172
CheckedFrom c = "mps_convolution";
@@ -256,6 +257,7 @@ Tensor mps_convolution_backward_input(
256257
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
257258
namespace native_mps = at::native::mps;
258259
using namespace mps;
260+
TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
259261
CheckedFrom c = "mps_convolution_backward_input";
260262
TensorArg grad_output{ grad_output_t, "grad_output", 1 },
261263
weight{ weight_t, "weight", 2 };
@@ -392,6 +394,7 @@ Tensor mps_convolution_backward_weights(
392394
IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
393395
namespace native_mps = at::native::mps;
394396
using namespace mps;
397+
TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
395398
CheckedFrom c = "mps_convolution_backward_weights";
396399
auto memory_format = grad_output_t.suggest_memory_format();
397400
bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);

0 commit comments

Comments
 (0)