Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions torch/onnx/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,28 @@ def avg_pool3d(g, input, kernel_size, stride, padding, ceil_mode, count_include_
pads_i=_triple(padding))


def reflection_pad(g, input, padding):
from torch.autograd._functions.utils import prepare_onnx_paddings
mode = "reflect"
paddings = prepare_onnx_paddings(len(input.type().sizes()), padding)
return g.op("Pad", input, pads_i=paddings, mode_s=mode)


def replication_pad(g, input, padding):
from torch.autograd._functions.utils import prepare_onnx_paddings
mode = "edge"
paddings = prepare_onnx_paddings(len(input.type().sizes()), padding)
return g.op("Pad", input, pads_i=paddings, mode_s=mode)


reflection_pad1d = reflection_pad
reflection_pad2d = reflection_pad
reflection_pad3d = reflection_pad
replication_pad1d = replication_pad
replication_pad2d = replication_pad
replication_pad3d = replication_pad


def log_softmax(g, input, dim=None):
return g.op("Log", g.op('Softmax', input, axis_i=dim).setTypeAs(input))

Expand Down