Skip to content

Commit 40096c9

Browse files
bddppqezyang
authored andcommitted
Support export torch.max(input, dim) and torch.min(input, dim) to ONNX (#6220)
* Support export torch.max(input, dim) and torch.min(input, dim) to ONNX * .
1 parent 8392639 commit 40096c9

File tree

1 file changed

+38
-6
lines changed

1 file changed

+38
-6
lines changed

torch/onnx/symbolic.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numbers
2+
13
import torch
24
from torch.nn.modules.utils import _single, _pair, _triple
35
from torch.nn.utils.rnn import PackedSequence
@@ -562,12 +564,42 @@ def clamp(g, self, min, max):
562564
return g.op("Clip", self, min_f=min, max_f=max)
563565

564566

565-
def max(g, self, other):
566-
return g.op("Max", self, other)
567-
568-
569-
def min(g, self, other):
570-
return g.op("Min", self, other)
567+
# torch.max (same for torch.min) actually has two interfaces smashed together:
568+
# torch.max(x, dim, keepdim) and torch.max(x, y)
569+
def max(g, self, *args, **kwargs):
570+
dim = kwargs.get("dim", None)
571+
if dim is None and isinstance(args[0], numbers.Number):
572+
dim = args[0]
573+
if dim is not None:
574+
keepdim = kwargs.get("keepdim", False)
575+
# TODO: export it as ReduceMax
576+
return g.op("ATen",
577+
self,
578+
operator_s="max",
579+
dim_i=dim,
580+
keepdim_i=keepdim,
581+
outputs=2)
582+
else:
583+
(other,) = args
584+
return g.op("Max", self, other)
585+
586+
587+
def min(g, self, *args, **kwargs):
588+
dim = kwargs.get("dim", None)
589+
if dim is None and isinstance(args[0], numbers.Number):
590+
dim = args[0]
591+
if dim is not None:
592+
keepdim = kwargs.get("keepdim", False)
593+
# TODO: export it as ReduceMin
594+
return g.op("ATen",
595+
self,
596+
operator_s="min",
597+
dim_i=dim,
598+
keepdim_i=keepdim,
599+
outputs=2)
600+
else:
601+
(other,) = args
602+
return g.op("Min", self, other)
571603

572604

573605
def eq(g, self, other):

0 commit comments

Comments
 (0)