Skip to content

Commit 7cbef70

Browse files
houseroadezyang
authored andcommitted
Fix the onnx symbolic for selu and maxpool3d (#6816)
1 parent 645ad7a commit 7cbef70

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

torch/onnx/symbolic.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,20 @@ def max_pool2d(g, input, kernel_size, stride, padding, dilation, ceil_mode):
406406
return r, None
407407

408408

409+
def max_pool3d(g, input, kernel_size, stride, padding, dilation, ceil_mode):
410+
if ceil_mode:
411+
return _unimplemented("max_pool3d", "ceil_mode")
412+
if set(_triple(dilation)) != {1}:
413+
return _unimplemented("max_pool3d", "dilation")
414+
if not stride:
415+
stride = kernel_size
416+
r = g.op("MaxPool", input,
417+
kernel_shape_i=_triple(kernel_size),
418+
pads_i=_triple(padding) * 2,
419+
strides_i=_triple(stride))
420+
return r, None
421+
422+
409423
def avg_pool2d(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad):
410424
if ceil_mode:
411425
return _unimplemented("avg_pool2d", "ceil_mode")
@@ -539,7 +553,9 @@ def unfold(g, input, dimension, size, step):
539553
return g.op("ATen", input, operator_s="unfold", dimension_i=dimension, size_i=size, step_i=step)
540554

541555

542-
def elu(g, input, alpha, inplace=False):
556+
def elu(g, input, alpha, scale):
557+
if scale and scale != 1.:
558+
return _unimplemented("scale", "does not support scale in Elu")
543559
# See Note [Export inplace]
544560
return g.op("Elu", input, alpha_f=_scalar(alpha))
545561

0 commit comments

Comments
 (0)