|
6 | 6 | # This import monkey-patches graph manipulation methods on Graph, used for the |
7 | 7 | # ONNX symbolics |
8 | 8 | import torch.onnx.utils |
| 9 | +from sys import maxsize |
9 | 10 |
|
10 | 11 | import torch.onnx.symbolic_helper as sym_help |
11 | 12 | from torch.onnx.symbolic_helper import parse_args, _unimplemented |
@@ -179,6 +180,89 @@ def flip(g, input, dims): |
179 | 180 | def fmod(g, input, other): |
180 | 181 | return g.op("Mod", input, other, fmod_i=1) |
181 | 182 |
|
| 183 | + |
| 184 | +@parse_args('v', 'v', 'v', 'i', 'i', 'i', 'v', 'i') |
| 185 | +def embedding_bag(g, |
| 186 | + embedding_matrix, |
| 187 | + indices, |
| 188 | + offsets, |
| 189 | + scale_grad_by_freq, |
| 190 | + mode, |
| 191 | + sparse, |
| 192 | + per_sample_weights, |
| 193 | + include_last_offset): |
| 194 | + if scale_grad_by_freq and sym_help._training_mode: |
| 195 | + return sym_help._onnx_unsupported('embedding_bag with scale_grad_by_freq for training mode') |
| 196 | + |
| 197 | + from torch.onnx.symbolic_opset9 import size, div, select |
| 198 | + |
| 199 | + # Check if initial indices was 2D. In functional.py: |
| 200 | + # offsets is set to torch.arange(0, indices.numel(), indices.size(1)) |
| 201 | + # Then indices is reshaped to 1D: indices.reshape(-1) |
| 202 | + if len(list(indices.node().inputs())) > 0 and indices.node().inputs().__next__().type().sizes() is not None \ |
| 203 | + and len(indices.node().inputs().__next__().type().sizes()) == 2: |
| 204 | + # Assert include_last_offset is False |
| 205 | + assert not include_last_offset |
| 206 | + embeddings = g.op("Gather", embedding_matrix, indices) |
| 207 | + dim_0 = size(g, offsets, g.op("Constant", value_t=torch.LongTensor([0]))) |
| 208 | + dim_1 = div(g, size(g, indices, g.op("Constant", value_t=torch.LongTensor([0]))), dim_0) |
| 209 | + dim_2 = g.op("Constant", value_t=torch.LongTensor([-1])) |
| 210 | + |
| 211 | + shape = [dim_0, dim_1, dim_2] |
| 212 | + shape = g.op("Concat", *shape, axis_i=0) |
| 213 | + |
| 214 | + if not sym_help._is_none(per_sample_weights): |
| 215 | + per_sample_weights = g.op("Unsqueeze", per_sample_weights, axes_i=[1]) |
| 216 | + embeddings = g.op("Mul", embeddings, per_sample_weights) |
| 217 | + |
| 218 | + embeddings = g.op("Reshape", embeddings, shape) |
| 219 | + if mode == 0: |
| 220 | + embeddings = g.op("ReduceSum", embeddings, axes_i=[1], keepdims_i=0) |
| 221 | + elif mode == 1: |
| 222 | + embeddings = g.op("ReduceMean", embeddings, axes_i=[1], keepdims_i=0) |
| 223 | + else: |
| 224 | + embeddings = g.op("ReduceMax", embeddings, axes_i=[1], keepdims_i=0) |
| 225 | + # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices. |
| 226 | + # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. |
| 227 | + return embeddings, None, None, None |
| 228 | + elif offsets.type().sizes() is not None: |
| 229 | + if include_last_offset: |
| 230 | + offset_len = offsets.type().sizes()[0] - 1 |
| 231 | + offsets_extended = offsets |
| 232 | + else: |
| 233 | + offset_len = offsets.type().sizes()[0] |
| 234 | + offsets_extended = [offsets, g.op("Constant", value_t=torch.tensor([maxsize]))] |
| 235 | + offsets_extended = g.op("Concat", *offsets_extended, axis_i=0) |
| 236 | + list_ = [] |
| 237 | + for i in range(offset_len): |
| 238 | + start_ = g.op("Unsqueeze", select(g, offsets_extended, torch.tensor(0), torch.tensor(i)), axes_i=[0]) |
| 239 | + end_ = g.op("Unsqueeze", select(g, offsets_extended, torch.tensor(0), torch.tensor(i + 1)), axes_i=[0]) |
| 240 | + axes_ = g.op("Constant", value_t=torch.tensor([0])) |
| 241 | + indices_row = g.op("Slice", indices, start_, end_, axes_) |
| 242 | + |
| 243 | + embeddings = g.op("Gather", embedding_matrix, indices_row) |
| 244 | + if not sym_help._is_none(per_sample_weights): |
| 245 | + per_sample_weights_row = g.op("Slice", per_sample_weights, start_, end_, axes_) |
| 246 | + per_sample_weights_row = g.op("Unsqueeze", per_sample_weights_row, axes_i=[1]) |
| 247 | + embeddings = g.op("Mul", embeddings, per_sample_weights_row) |
| 248 | + if mode == 0: |
| 249 | + embeddings = g.op("ReduceSum", embeddings, axes_i=[0], keepdims_i=0) |
| 250 | + elif mode == 1: |
| 251 | + embeddings = g.op("ReduceMean", embeddings, axes_i=[0], keepdims_i=0) |
| 252 | + else: |
| 253 | + embeddings = g.op("ReduceMax", embeddings, axes_i=[0], keepdims_i=0) |
| 254 | + |
| 255 | + embeddings = g.op("Unsqueeze", embeddings, axes_i=[0]) |
| 256 | + list_.append(embeddings) |
| 257 | + |
| 258 | + output = g.op("Concat", *list_, axis_i=0) |
| 259 | + # aten::embedding_bag returns a tuple of 4 elements: output, offset2bag, bag_size, max_indices. |
| 260 | + # But the last three outputs are not used in torch.nn.EmbeddingBag or torch.nn.functional.embedding_bag. |
| 261 | + return output, None, None, None |
| 262 | + else: |
| 263 | + return sym_help._onnx_unsupported('embedding_bag with unknown shape of indices') |
| 264 | + |
| 265 | + |
182 | 266 | @parse_args('v', 't', 'i', 'i', 'i') |
183 | 267 | def fake_quantize_per_tensor_affine(g, inputs, scale, zero_point, quant_min=-128, quant_max=127): |
184 | 268 | if quant_min not in [0, -128] or quant_max not in [127, 255]: |
|
0 commit comments