
Commit 346c69a

neginraoof authored and facebook-github-bot committed
[ONNX] Export embedding_bag (#41234)
Summary: Enable export of the embedding_bag op to ONNX.

Pull Request resolved: #41234
Reviewed By: houseroad
Differential Revision: D22567470
Pulled By: bzinodev
fbshipit-source-id: 2fcf74e54f3a9dee4588d7877a4ac9eb6c2a3629
1 parent 7eb71b4 commit 346c69a
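
For context, a minimal sketch of what this commit enables (the hyperparameters and file name below are illustrative, not taken from the commit). With the symbolic registered in torch/onnx/symbolic_opset10.py, an EmbeddingBag module can be exported at opset 10 or higher:

import torch

model = torch.nn.EmbeddingBag(10, 5, mode='sum')
input = torch.randint(10, (7,))
offsets = torch.tensor([0, 2, 5, 6])

# In 1D mode, EmbeddingBag takes (input, offsets); opset_version >= 10 is
# required because the symbolic lives in symbolic_opset10.py.
torch.onnx.export(model, (input, offsets), "embedding_bag.onnx", opset_version=10)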

File tree

2 files changed: +147 -0 lines changed

test/onnx/test_pytorch_onnx_onnxruntime.py

Lines changed: 63 additions & 0 deletions

@@ -2894,6 +2894,69 @@ def forward(self, input):
         x = torch.tensor([False, True, True])
         self.run_test(model, x)
 
+    @unittest.skip("Enable once jit trace Tensor.numel as constant is fixed.")
+    def test_embedding_bag_dynamic(self):
+        class EmbeddingModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.embeddingbag = torch.nn.EmbeddingBag(40, 12, mode='sum')
+
+            def forward(self, input):
+                return self.embeddingbag(input)
+
+        model = EmbeddingModel()
+        x = torch.randint(7, (10, 5))
+        y = torch.randint(10, (20, 5))
+        self.run_test(model, x, test_with_inputs=[y],
+                      input_names=['input'],
+                      output_names=['output'],
+                      dynamic_axes={'input': [0],
+                                    'output': [0]
+                                    })
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_embedding_bag(self):
+        model = torch.nn.EmbeddingBag(10, 5, mode='sum', scale_grad_by_freq=True)
+        input = torch.randint(10, (7,))
+        offset = torch.tensor([0, 2, 5, 6])
+        self.run_test(model, (input, offset))
+
+        model = torch.nn.EmbeddingBag(10, 5, mode='sum', include_last_offset=True)
+        input = torch.randint(10, (7,))
+        offset = torch.tensor([0, 2, 5, 6])
+        self.run_test(model, (input, offset))
+
+        model = torch.nn.EmbeddingBag(10, 5, mode='max')
+        input = torch.randint(10, (7, 5))
+        self.run_test(model, (input,))
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_embedding_bag_1d_per_sample_weights(self):
+        class EmbeddingModel(torch.nn.Module):
+            def forward(self, embedding_matrix, input, offset, weights):
+                return torch.nn.functional.embedding_bag(embedding_matrix, input, offsets=offset,
+                                                         mode='sum', per_sample_weights=weights)
+
+        model = EmbeddingModel()
+        x = torch.randint(7, (6,))
+        w = torch.randn(6,)
+        offset = torch.tensor([0, 2, 5])
+        embedding_matrix = torch.rand(10, 15)
+        self.run_test(model, (embedding_matrix, x, offset, w))
+
+    @skipIfUnsupportedMinOpsetVersion(10)
+    def test_embedding_bag_2d_per_sample_weights(self):
+        class EmbeddingModel(torch.nn.Module):
+            def forward(self, embedding_matrix, input, weights):
+                return torch.nn.functional.embedding_bag(embedding_matrix, input,
+                                                         mode='sum', per_sample_weights=weights)
+
+        embedding_matrix = torch.rand(10, 15)
+        model = EmbeddingModel()
+        x = torch.randint(7, (2, 3))
+        w = torch.randn(2, 3)
+        self.run_test(model, (embedding_matrix, x, w))
+
     @skipIfUnsupportedMinOpsetVersion(8)
     def test_meshgrid(self):
         class Meshgrid(torch.nn.Module):
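
A note on the mode='max' case above, which passes a 2D input and no offsets: with 2D indices, each row is treated as one fixed-size bag. A rough eager-mode restatement of what that test checks (a sketch, not part of the commit):

import torch

emb = torch.nn.EmbeddingBag(10, 5, mode='max')
input = torch.randint(10, (7, 5))

# Row i of `input` is bag i, so the result is a reduction over dim 1.
rows = emb.weight[input]             # shape (7, 5, 5): one vector per index
expected = rows.max(dim=1).values    # shape (7, 5): per-bag max
assert torch.allclose(emb(input), expected)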

torch/onnx/symbolic_opset10.py

Lines changed: 84 additions & 0 deletions

@@ -6,6 +6,7 @@
 # This import monkey-patches graph manipulation methods on Graph, used for the
 # ONNX symbolics
 import torch.onnx.utils
+from sys import maxsize
 
 import torch.onnx.symbolic_helper as sym_help
 from torch.onnx.symbolic_helper import parse_args, _unimplemented

@@ -179,6 +180,89 @@ def flip(g, input, dims):
 def fmod(g, input, other):
     return g.op("Mod", input, other, fmod_i=1)
 
+
+@parse_args('v', 'v', 'v', 'i', 'i', 'i', 'v', 'i')
+def embedding_bag(g,
+                  embedding_matrix,
+                  indices,
+                  offsets,
+                  scale_grad_by_freq,
+                  mode,
+                  sparse,
+                  per_sample_weights,
+                  include_last_offset):
+    if scale_grad_by_freq and sym_help._training_mode:
+        return sym_help._onnx_unsupported('embedding_bag with scale_grad_by_freq for training mode')
+
+    from torch.onnx.symbolic_opset9 import size, div, select
+
+    # Check if initial indices was 2D. In functional.py:
+    # offsets is set to torch.arange(0, indices.numel(), indices.size(1))
+    # Then indices is reshaped to 1D: indices.reshape(-1)
+    if len(list(indices.node().inputs())) > 0 and indices.node().inputs().__next__().type().sizes() is not None \
+            and len(indices.node().inputs().__next__().type().sizes()) == 2:
+        # Assert include_last_offset is False
+        assert not include_last_offset
+        embeddings = g.op("Gather", embedding_matrix, indices)
+        dim_0 = size(g, offsets, g.op("Constant", value_t=torch.LongTensor([0])))
+        dim_1 = div(g, size(g, indices, g.op("Constant", value_t=torch.LongTensor([0]))), dim_0)
+        dim_2 = g.op("Constant", value_t=torch.LongTensor([-1]))
+
+        shape = [dim_0, dim_1, dim_2]
+        shape = g.op("Concat", *shape, axis_i=0)
+
+        if not sym_help._is_none(per_sample_weights):
+            per_sample_weights = g.op("Unsqueeze", per_sample_weights, axes_i=[1])
+            embeddings = g.op("Mul", embeddings, per_sample_weights)
+
+        embeddings = g.op("Reshape", embeddings, shape)
+        if mode == 0:
+            embeddings = g.op("ReduceSum", embeddings, axes_i=[1], keepdims_i=0)
+        elif mode == 1:
+            embeddings = g.op("ReduceMean", embeddings, axes_i=[1], keepdims_i=0)
+        else:
+            embeddings = g.op("ReduceMax", embeddings, axes_i=[1], keepdims_i=0)
+        # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
+        # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
+        return embeddings, None, None, None
+    elif offsets.type().sizes() is not None:
+        if include_last_offset:
+            offset_len = offsets.type().sizes()[0] - 1
+            offsets_extended = offsets
+        else:
+            offset_len = offsets.type().sizes()[0]
+            offsets_extended = [offsets, g.op("Constant", value_t=torch.tensor([maxsize]))]
+            offsets_extended = g.op("Concat", *offsets_extended, axis_i=0)
+        list_ = []
+        for i in range(offset_len):
+            start_ = g.op("Unsqueeze", select(g, offsets_extended, torch.tensor(0), torch.tensor(i)), axes_i=[0])
+            end_ = g.op("Unsqueeze", select(g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)), axes_i=[0])
+            axes_ = g.op("Constant", value_t=torch.tensor([0]))
+            indices_row = g.op("Slice", indices, start_, end_, axes_)
+
+            embeddings = g.op("Gather", embedding_matrix, indices_row)
+            if not sym_help._is_none(per_sample_weights):
+                per_sample_weights_row = g.op("Slice", per_sample_weights, start_, end_, axes_)
+                per_sample_weights_row = g.op("Unsqueeze", per_sample_weights_row, axes_i=[1])
+                embeddings = g.op("Mul", embeddings, per_sample_weights_row)
+            if mode == 0:
+                embeddings = g.op("ReduceSum", embeddings, axes_i=[0], keepdims_i=0)
+            elif mode == 1:
+                embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0)
+            else:
+                embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0)
+
+            embeddings = g.op("Unsqueeze", embeddings, axes_i=[0])
+            list_.append(embeddings)
+
+        output = g.op("Concat", *list_, axis_i=0)
+        # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices.
+        # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag.
+        return output, None, None, None
+    else:
+        return sym_help._onnx_unsupported('embedding_bag with unknown shape of indices')
+
+
 @parse_args('v', 't', 'i', 'i', 'i')
 def fake_quantize_per_tensor_affine(g, inputs, scale, zero_point, quant_min=-128, quant_max=127):
     if quant_min not in [0, -128] or quant_max not in [127, 255]:
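
To make the offsets-based branch concrete: for each bag i, the symbolic slices indices[offsets[i]:offsets[i+1]], gathers the embedding rows, optionally multiplies in per-sample weights, reduces over the bag, and concatenates the per-bag results. Below is a small eager-mode sketch of the same computation (not part of the commit; it assumes mode='sum' and include_last_offset=False, and uses indices.numel() where the symbolic appends sys.maxsize as the end sentinel). The 2D-indices branch is the vectorized analogue: reshape the gathered rows to (num_bags, bag_size, -1), then reduce over axis 1.

import torch

def embedding_bag_sum_reference(embedding_matrix, indices, offsets, per_sample_weights=None):
    # End of bag i is the start of bag i + 1; the last bag runs to the end.
    ends = torch.cat([offsets[1:], torch.tensor([indices.numel()])])
    bags = []
    for start, end in zip(offsets.tolist(), ends.tolist()):
        rows = embedding_matrix[indices[start:end]]                   # Gather
        if per_sample_weights is not None:
            rows = rows * per_sample_weights[start:end].unsqueeze(1)  # Mul
        bags.append(rows.sum(dim=0, keepdim=True))                    # ReduceSum
    return torch.cat(bags, dim=0)                                     # Concat

This should agree with the torch.nn.functional.embedding_bag calls used in the tests above when mode='sum'.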
