
Commit ba64724

James Reed authored and ezyang committed
Softmax symbolic should account for negative dim (#5846)
1 parent 22ef8e5 commit ba64724

File tree: 1 file changed (+3, -1 lines)


torch/onnx/symbolic.py (3 additions, 1 deletion)

```diff
@@ -356,7 +356,9 @@ def softmax(g, input, dim=None):
     # [0.167, 0.167, 0.167]]
     # So only when dim and axis both equal to ndim - 1 (the last dimension),
     # their semantics are equivalent.
-    if len(input.type().sizes()) != dim + 1:
+    if dim < 0:
+        check_dim = len(input.type().sizes()) + dim
+    if len(input.type().sizes()) != check_dim + 1:
         return _unimplemented("dim", "ONNX and PyTorch use different strategies to split the input.")
     return g.op('Softmax', input, axis_i=dim)
```
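The fix normalizes a negative `dim` using Python's negative-index convention before comparing it against the last axis. A minimal sketch of that normalization (the helper name `normalize_dim` and the example shapes are illustrative, not from the commit):

```python
def normalize_dim(dim, ndim):
    # Negative dims count from the end, as in Python indexing:
    # dim=-1 refers to the last axis, i.e. ndim - 1.
    return ndim + dim if dim < 0 else dim

# The exporter only emits ONNX Softmax when dim refers to the last axis,
# i.e. when the normalized dim equals ndim - 1.
assert normalize_dim(-1, 3) == 2  # last axis of a 3-D tensor: exportable
assert normalize_dim(2, 3) == 2   # same axis, written positively
assert normalize_dim(-3, 3) == 0  # first axis: falls into _unimplemented
```

Before this change, a model using `softmax(dim=-1)` would fail the `!= dim + 1` check even though `-1` denotes the last dimension.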

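The comment in the hunk alludes to why only the last dimension is exportable: ONNX's Softmax (in the opset generation this code targets) coerces its input to 2-D around `axis` and normalizes over the whole flattened trailing block, while PyTorch normalizes over a single dimension. A NumPy sketch of that mismatch, under the stated assumption about the ONNX coercion behavior (function names here are illustrative):

```python
import numpy as np

def torch_style_softmax(x, dim):
    # PyTorch semantics: normalize over exactly one dimension.
    e = np.exp(x - x.max(axis=dim, keepdims=True))
    return e / e.sum(axis=dim, keepdims=True)

def onnx_coerced_softmax(x, axis):
    # Assumed ONNX behavior: flatten to 2-D [prod(dims before axis), rest]
    # and normalize over the entire trailing block.
    flat = x.reshape(int(np.prod(x.shape[:axis], dtype=int)), -1)
    e = np.exp(flat - flat.max(axis=1, keepdims=True))
    return (e / e.sum(axis=1, keepdims=True)).reshape(x.shape)

rng = np.random.RandomState(0)
x = rng.rand(2, 3, 4)

# The two agree on the last dimension (the only case the exporter allows)...
assert np.allclose(torch_style_softmax(x, 2), onnx_coerced_softmax(x, 2))
# ...but not on earlier dimensions, hence _unimplemented for those.
assert not np.allclose(torch_style_softmax(x, 1), onnx_coerced_softmax(x, 1))
```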
0 commit comments
