Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions test/onnx/test_pytorch_onnx_caffe2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,22 @@ def forward(self, input):
x = torch.randn(1, 2, 3, 4, requires_grad=True)
self.run_model_test(FlattenModel(), train=False, input=x, batch_size=BATCH_SIZE)

def test_max(self):
class MaxModel(torch.nn.Module):
def forward(self, input):
return torch.max(input, dim=1)

x = torch.randn(4, 4, requires_grad=True)
self.run_model_test(MaxModel(), train=False, input=x, batch_size=BATCH_SIZE)

def test_min(self):
class MinModel(torch.nn.Module):
def forward(self, input):
return torch.min(input, dim=1)

x = torch.randn(4, 4, requires_grad=True)
self.run_model_test(MinModel(), train=False, input=x, batch_size=BATCH_SIZE)

def test_argmax(self):
class ArgmaxModel(torch.nn.Module):
def forward(self, input):
Expand Down
26 changes: 12 additions & 14 deletions torch/onnx/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,37 +1117,35 @@ def clamp_max(g, self, max):
# torch.max (same for torch.min) actually has two interfaces smashed together:
# torch.max(x, dim, keepdim) and torch.max(x, y)
def max(g, self, dim_or_y=None, keepdim=None):
# torch.max(input)
if dim_or_y is None and keepdim is None:
return g.op("ReduceMax", self, keepdims_i=0)
# torch.max(input, other)
if keepdim is None:
return g.op("Max", self, dim_or_y)
# torch.max(input, dim, keepdim)
else:
dim = _get_const(dim_or_y, 'i', 'dim')
keepdim = _get_const(keepdim, 'i', 'keepdim')
# TODO: export it as ReduceMax
return g.op("ATen",
self,
operator_s="max",
dim_i=dim,
keepdim_i=keepdim,
outputs=2)
max = g.op("ReduceMax", self, axes_i=[dim], keepdims_i=keepdim)
indices = g.op('ArgMax', self, axis_i=dim, keepdims_i=keepdim)
return max, indices


def min(g, self, dim_or_y=None, keepdim=None):
# torch.min(input)
if dim_or_y is None and keepdim is None:
return g.op("ReduceMin", self, keepdims_i=0)
# torch.min(input, other)
if keepdim is None:
return g.op("Min", self, dim_or_y)
# torch.min(input, dim, keepdim)
else:
dim = _get_const(dim_or_y, 'i', 'dim')
keepdim = _get_const(keepdim, 'i', 'keepdim')
# TODO: export it as ReduceMax
return g.op("ATen",
self,
operator_s="min",
dim_i=dim,
keepdim_i=keepdim,
outputs=2)
min = g.op("ReduceMin", self, axes_i=[dim], keepdims_i=keepdim)
indices = g.op('ArgMin', self, axis_i=dim, keepdims_i=keepdim)
return min, indices


def exp(g, self):
Expand Down