Skip to content

Commit 63a55d4

Browse files
BowenBao authored and facebook-github-bot committed
Support gather export with OneHot + Mul (#21235)
Summary: This could serve as a alternative solution to export ```torch.gather``` before something similar goes into ONNX spec. The exported model is verified to be correct against onnxruntime backend. We weren't able to test against Caffe2 backend because it doesn't seem to support OneHot opset9. Pull Request resolved: #21235 Differential Revision: D15613039 Pulled By: houseroad fbshipit-source-id: 7fc097f85235c071474730233ede7d83074c347f
1 parent 240d62f commit 63a55d4

File tree

3 files changed

+178
-0
lines changed

3 files changed

+178
-0
lines changed
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
ir_version: 4
2+
producer_name: "pytorch"
3+
producer_version: "1.1"
4+
graph {
5+
node {
6+
output: "2"
7+
op_type: "Constant"
8+
attribute {
9+
name: "value"
10+
t {
11+
dims: 2
12+
data_type: 7
13+
raw_data: "\000\000\000\000\000\000\000\000\001\000\000\000\000\000\000\000"
14+
}
15+
type: TENSOR
16+
}
17+
}
18+
node {
19+
output: "3"
20+
op_type: "Constant"
21+
attribute {
22+
name: "value"
23+
t {
24+
dims: 1
25+
data_type: 7
26+
raw_data: "\001\000\000\000\000\000\000\000"
27+
}
28+
type: TENSOR
29+
}
30+
}
31+
node {
32+
input: "0"
33+
output: "4"
34+
op_type: "Shape"
35+
}
36+
node {
37+
input: "4"
38+
input: "3"
39+
output: "5"
40+
op_type: "Gather"
41+
attribute {
42+
name: "axis"
43+
i: 0
44+
type: INT
45+
}
46+
}
47+
node {
48+
input: "1"
49+
input: "5"
50+
input: "2"
51+
output: "6"
52+
op_type: "OneHot"
53+
attribute {
54+
name: "axis"
55+
i: 1
56+
type: INT
57+
}
58+
}
59+
node {
60+
input: "6"
61+
output: "7"
62+
op_type: "Cast"
63+
attribute {
64+
name: "to"
65+
i: 1
66+
type: INT
67+
}
68+
}
69+
node {
70+
input: "0"
71+
output: "8"
72+
op_type: "Unsqueeze"
73+
attribute {
74+
name: "axes"
75+
ints: 2
76+
type: INTS
77+
}
78+
}
79+
node {
80+
input: "8"
81+
input: "7"
82+
output: "9"
83+
op_type: "Mul"
84+
}
85+
node {
86+
input: "9"
87+
output: "10"
88+
op_type: "ReduceSum"
89+
attribute {
90+
name: "axes"
91+
ints: 1
92+
type: INTS
93+
}
94+
attribute {
95+
name: "keepdims"
96+
i: 0
97+
type: INT
98+
}
99+
}
100+
name: "torch-jit-export"
101+
input {
102+
name: "0"
103+
type {
104+
tensor_type {
105+
elem_type: 1
106+
shape {
107+
dim {
108+
dim_value: 3
109+
}
110+
dim {
111+
dim_value: 4
112+
}
113+
dim {
114+
dim_value: 3
115+
}
116+
}
117+
}
118+
}
119+
}
120+
input {
121+
name: "1"
122+
type {
123+
tensor_type {
124+
elem_type: 7
125+
shape {
126+
dim {
127+
dim_value: 3
128+
}
129+
dim {
130+
dim_value: 2
131+
}
132+
dim {
133+
dim_value: 3
134+
}
135+
}
136+
}
137+
}
138+
}
139+
output {
140+
name: "10"
141+
type {
142+
tensor_type {
143+
elem_type: 1
144+
shape {
145+
dim {
146+
dim_value: 3
147+
}
148+
dim {
149+
dim_value: 2
150+
}
151+
dim {
152+
dim_value: 3
153+
}
154+
}
155+
}
156+
}
157+
}
158+
}
159+
opset_import {
160+
version: 9
161+
}

test/onnx/test_operators.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,11 @@ def test_nonzero(self):
586586
x = torch.tensor([[[2., 2.], [1., 0.]], [[0., 0.], [1., 1.]]], requires_grad=True)
587587
self.assertONNX(lambda x: torch.nonzero(x), x)
588588

589+
def test_gather(self):
    """Check that torch.gather exports through the OneHot + Mul workaround.

    Exports ``data.gather(1, index)`` and compares the emitted graph
    against the recorded ONNX expect file.
    """
    src = torch.randn(3, 4, 3, requires_grad=True)
    idx = torch.tensor([2, 0]).view(1, 2, 1).expand(3, 2, 3)
    self.assertONNX(lambda src, idx: src.gather(1, idx), (src, idx))
593+
589594
def test_master_opset(self):
590595
x = torch.randn(2, 3).float()
591596
y = torch.randn(2, 3).float()

torch/onnx/symbolic_opset9.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,3 +1600,15 @@ def log2(g, self):
16001600

16011601
def prim_shape(g, self):
    """Symbolic for ``prim::shape``: emit a single ONNX Shape node over the input."""
    return g.op('Shape', self)
1603+
1604+
1605+
@parse_args('v', 'i', 'v', 'v')
def gather(g, self, dim, index, sparse_grad=False):
    """Symbolic for ``torch.gather`` emulated with OneHot + Mul + ReduceSum.

    ONNX's native Gather has different semantics from ``torch.gather``, so
    the op is emulated: one-hot encode ``index`` along ``dim`` (depth taken
    from ``self``'s size at ``dim``), multiply against ``self`` unsqueezed at
    ``dim + 1``, and reduce-sum away the original ``dim``.

    Arguments:
        g: graph context used to emit ONNX nodes.
        self: input tensor value to gather from.
        dim (int): dimension along which to gather.
        index: tensor value of indices into ``self`` along ``dim``.
        sparse_grad: unsupported in ONNX; a constant True value raises.
    """
    # NOTE: Update this workaround if ONNX gains a native torch.gather equivalent.
    if sym_help._maybe_get_const(sparse_grad, 'i'):
        # sparse_grad has no ONNX counterpart; previously it was silently
        # ignored, which would export a graph with different autograd
        # semantics. Fail loudly instead.
        raise RuntimeError("ONNX export of gather with sparse_grad=True is not supported")
    dtype = self.type().scalarType()
    # OneHot emits values[0] (0) for misses and values[1] (1) for hits.
    values = g.op("Constant", value_t=torch.LongTensor([0, 1]))
    # depth = self.shape[dim], computed dynamically from the input tensor.
    depth = size(g, self, g.op("Constant", value_t=torch.LongTensor([dim])))
    # Cast the int one-hot mask to the input's dtype so Mul type-checks.
    index = g.op("Cast", g.op("OneHot", index, depth, values, axis_i=dim),
                 to_i=sym_help.cast_pytorch_to_onnx[dtype])
    mul = g.op("Mul", g.op("Unsqueeze", self, axes_i=[dim + 1]), index)
    return g.op("ReduceSum", mul, axes_i=[dim], keepdims_i=0)

0 commit comments

Comments
 (0)