@@ -1,3 +1,5 @@
+import numbers
+
 import torch
 from torch.nn.modules.utils import _single, _pair, _triple
 from torch.nn.utils.rnn import PackedSequence
@@ -562,12 +564,42 @@ def clamp(g, self, min, max):
     return g.op("Clip", self, min_f=min, max_f=max)
 
 
-def max(g, self, other):
-    return g.op("Max", self, other)
-
-
-def min(g, self, other):
-    return g.op("Min", self, other)
+# torch.max (same for torch.min) actually has two interfaces smashed together:
+# torch.max(x, dim, keepdim) and torch.max(x, y)
+def max(g, self, *args, **kwargs):
+    dim = kwargs.get("dim", None)
+    if dim is None and isinstance(args[0], numbers.Number):
+        dim = args[0]
+    if dim is not None:
+        keepdim = kwargs.get("keepdim", False)
+        # TODO: export it as ReduceMax
+        return g.op("ATen",
+                    self,
+                    operator_s="max",
+                    dim_i=dim,
+                    keepdim_i=keepdim,
+                    outputs=2)
+    else:
+        (other,) = args
+        return g.op("Max", self, other)
+
+
+def min(g, self, *args, **kwargs):
+    dim = kwargs.get("dim", None)
+    if dim is None and isinstance(args[0], numbers.Number):
+        dim = args[0]
+    if dim is not None:
+        keepdim = kwargs.get("keepdim", False)
+        # TODO: export it as ReduceMin
+        return g.op("ATen",
+                    self,
+                    operator_s="min",
+                    dim_i=dim,
+                    keepdim_i=keepdim,
+                    outputs=2)
+    else:
+        (other,) = args
+        return g.op("Min", self, other)
 
 
 def eq(g, self, other):
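For context, a minimal sketch of the two call forms the rewritten symbolics now dispatch between; the tensors x and y here are illustrative and not part of the patch:

    import torch

    x = torch.randn(3, 4)
    y = torch.randn(3, 4)

    # Elementwise form: maps to the ONNX "Max" op.
    m = torch.max(x, y)

    # Reduction form: returns (values, indices), hence outputs=2 above;
    # exported through the "ATen" fallback op until ReduceMax is wired up.
    values, indices = torch.max(x, dim=1, keepdim=True)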