@@ -406,6 +406,20 @@ def max_pool2d(g, input, kernel_size, stride, padding, dilation, ceil_mode):
406406 return r , None
407407
408408
409+ def max_pool3d (g , input , kernel_size , stride , padding , dilation , ceil_mode ):
410+ if ceil_mode :
411+ return _unimplemented ("max_pool3d" , "ceil_mode" )
412+ if set (_triple (dilation )) != {1 }:
413+ return _unimplemented ("max_pool3d" , "dilation" )
414+ if not stride :
415+ stride = kernel_size
416+ r = g .op ("MaxPool" , input ,
417+ kernel_shape_i = _triple (kernel_size ),
418+ pads_i = _triple (padding ) * 2 ,
419+ strides_i = _triple (stride ))
420+ return r , None
421+
422+
409423def avg_pool2d (g , input , kernel_size , stride , padding , ceil_mode , count_include_pad ):
410424 if ceil_mode :
411425 return _unimplemented ("avg_pool2d" , "ceil_mode" )
@@ -539,7 +553,9 @@ def unfold(g, input, dimension, size, step):
539553 return g .op ("ATen" , input , operator_s = "unfold" , dimension_i = dimension , size_i = size , step_i = step )
540554
541555
542- def elu (g , input , alpha , inplace = False ):
556+ def elu (g , input , alpha , scale ):
557+ if scale and scale != 1. :
558+ return _unimplemented ("scale" , "does not support scale in Elu" )
543559 # See Note [Export inplace]
544560 return g .op ("Elu" , input , alpha_f = _scalar (alpha ))
545561
0 commit comments