@@ -66,6 +66,7 @@ Tensor _mps_convolution_impl(
6666 int64_t groups,
6767 c10::optional<IntArrayRef> input_shape) {
6868 TORCH_CHECK (input_t .dim () < 5 , " Conv3D is not supported on MPS" );
69+ TORCH_CHECK (isFloatingType (input_t .scalar_type ()), " Convolution is supported only for Floating types" );
6970
7071 namespace native_mps = at::native::mps;
7172 CheckedFrom c = " mps_convolution" ;
@@ -256,6 +257,7 @@ Tensor mps_convolution_backward_input(
256257 IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
257258 namespace native_mps = at::native::mps;
258259 using namespace mps ;
260+ TORCH_CHECK (isFloatingType (grad_output_t .scalar_type ()), " Convolution is supported only for Floating types" );
259261 CheckedFrom c = " mps_convolution_backward_input" ;
260262 TensorArg grad_output{ grad_output_t , " grad_output" , 1 },
261263 weight{ weight_t , " weight" , 2 };
@@ -392,6 +394,7 @@ Tensor mps_convolution_backward_weights(
392394 IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool bias_defined) {
393395 namespace native_mps = at::native::mps;
394396 using namespace mps ;
397+ TORCH_CHECK (isFloatingType (grad_output_t .scalar_type ()), " Convolution is supported only for Floating types" );
395398 CheckedFrom c = " mps_convolution_backward_weights" ;
396399 auto memory_format = grad_output_t .suggest_memory_format ();
397400 bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast);
0 commit comments