Skip to content

Commit dde10e1

Browse files
houseroadsoumith
authored andcommitted
Add docs talking about how to adding symbolic for unsupported ops (#3741)
1 parent 7874f61 commit dde10e1

File tree

1 file changed

+136
-3
lines changed

1 file changed

+136
-3
lines changed

docs/source/onnx.rst

Lines changed: 136 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,16 @@ The following operators are supported:
148148
* log_softmax
149149
* unfold (experimental support with ATen-Caffe2 integration)
150150
* elu
151+
* concat
152+
* abs
153+
* index_select
154+
* pow
155+
* clamp
156+
* max
157+
* min
158+
* eq
159+
* exp
160+
* permute
151161
* Conv
152162
* BatchNorm
153163
* MaxPool1d (ceil_mode not supported)
@@ -159,7 +169,6 @@ The following operators are supported:
159169
* Dropout
160170
* FeatureDropout (training mode not supported)
161171
* Index (constant integer and tuple indices supported)
162-
* Negate
163172

164173
The operator set above is sufficient to export the following models:
165174

@@ -173,8 +182,132 @@ The operator set above is sufficient to export the following models:
173182
* VGG
174183
* `word_language_model <https://github.com/pytorch/examples/tree/master/word_language_model>`_
175184

176-
The interface for specifying operator definitions is highly experimental
177-
and undocumented; adventurous users should note that the APIs will probably
185+
Adding export support for operators is an *advance usage*.
186+
To achieve this, developers need to touch the source code of PyTorch.
187+
Please follow the `instructions <https://github.com/pytorch/pytorch#from-source>`_
188+
for installing PyTorch from source.
189+
If the wanted operator is standardized in ONNX, it should be easy to add
190+
support for exporting such operator (adding a symbolic function for the operator).
191+
To confirm whether the operator is standardized or not, please check the
192+
`ONNX operator list <http://https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_.
193+
194+
If the operator is an ATen operator, which means you can find the declaration
195+
of the function in ``torch/csrc/autograd/generated/VariableType.h``
196+
(available in generated code in PyTorch install dir), you should add the symbolic
197+
function in ``torch/onnx/symbolic.py`` and follow the instructions listed as below:
198+
199+
* Define the symbolic function in
200+
`torch/onnx/symbolic.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic.py>`_.
201+
Make sure the function has the same name as the ATen operator/function
202+
defined in ``VariableType.h``.
203+
* The first parameter is always the exported ONNX graph.
204+
Parameter names must EXACTLY match the names in ``VariableType.h``,
205+
because dispatch is done with keyword arguments.
206+
* Parameter ordering does NOT necessarily match what is in ``VariableType.h``,
207+
tensors (inputs) are always first, then non-tensor arguments.
208+
* In the symbolic function, if the operator is already standardized in ONNX,
209+
we only need to create a node to represent the ONNX operator in the graph.
210+
* If the input argument is a tensor, but ONNX asks for a scalar, we have to
211+
explicitly do the conversion. The helper function ``_scalar`` can convert a
212+
scalar tensor into a python scalar, and ``_if_scalar_type_as`` can turn a
213+
Python scalar into a PyTorch tensor.
214+
215+
If the operator is a non-ATen operator, the symbolic function has to be
216+
added in the corresponding PyTorch Function class. Please read the following
217+
instructions:
218+
219+
* Create a symbolic function named ``symbolic`` in the corresponding Function class.
220+
* The first parameter is always the exported ONNX graph.
221+
* Parameter names except the first must EXACTLY match the names in ``forward``.
222+
* The output tuple size must match the outputs of ``forward``.
223+
* In the symbolic function, if the operator is already standardized in ONNX,
224+
we just need to create a node to represent the ONNX operator in the graph.
225+
226+
Symbolic functions should be implemented in Python. All of these functions interact
227+
with Python methods which are implemented via C++-Python bindings,
228+
but intuitively the interface they provide looks like this::
229+
230+
231+
def operator/symbolic(g, *inputs):
232+
"""
233+
Modifies Graph (e.g., using "op"), adding the ONNX operations representing
234+
this PyTorch function, and returning a Value or tuple of Values specifying the
235+
ONNX outputs whose values correspond to the original PyTorch return values
236+
of the autograd Function (or None if an output is not supported by ONNX).
237+
238+
Arguments:
239+
g (Graph): graph to write the ONNX representation into
240+
inputs (Value...): list of values representing the variables which contain
241+
the inputs for this function
242+
"""
243+
244+
class Value(object):
245+
"""Represents an intermediate tensor value computed in ONNX."""
246+
def type(self):
247+
"""Returns the Type of the value."""
248+
249+
class Type(object):
250+
def sizes(self):
251+
"""Returns a tuple of ints representing the shape of a tensor this describes."""
252+
253+
class Graph(object):
254+
def op(self, opname, *inputs, **attrs):
255+
"""
256+
Create an ONNX operator 'opname', taking 'args' as inputs
257+
and attributes 'kwargs' and add it as a node to the current graph,
258+
returning the value representing the single output of this
259+
operator (see the `outputs` keyword argument for multi-return
260+
nodes).
261+
262+
The set of operators and the inputs/attributes they take
263+
is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md
264+
265+
Arguments:
266+
opname (string): The ONNX operator name, e.g., `Abs` or `Add`.
267+
args (Value...): The inputs to the operator; usually provided
268+
as arguments to the `symbolic` definition.
269+
kwargs: The attributes of the ONNX operator, with keys named
270+
according to the following convention: `alpha_f` indicates
271+
the `alpha` attribute with type `f`. The valid type specifiers are
272+
`f` (float), `i` (int), `s` (string) or `t` (Tensor). An attribute
273+
specified with type float accepts either a single float, or a
274+
list of floats (e.g., you would say `dims_i` for a `dims` attribute
275+
that takes a list of integers).
276+
outputs (int, optional): The number of outputs this operator returns;
277+
by default an operator is assumed to return a single output.
278+
If `outputs` is greater than one, this functions returns a tuple
279+
of output `Value`, representing each output of the ONNX operator
280+
in positional.
281+
"""
282+
283+
The ONNX graph C++ definition is in ``torch/csrc/jit/ir.h``.
284+
285+
Here is an example of handling missing symbolic function for ``elu`` operator.
286+
We try to export the model and see the error message as below::
287+
288+
UserWarning: ONNX export failed on elu because torch.onnx.symbolic.elu does not exist
289+
RuntimeError: ONNX export failed: Couldn't export operator elu
290+
291+
The export fails because PyTorch does not support exporting ``elu`` operator.
292+
We find ``virtual Tensor elu(const Tensor & input, Scalar alpha, bool inplace) const override;``
293+
in ``VariableType.h``. This means ``elu`` is an ATen operator.
294+
We check the `ONNX operator list <http://https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_,
295+
and confirm that ``Elu`` is standardized in ONNX.
296+
We add the following lines to ``symbolic.py``::
297+
298+
def elu(g, input, alpha, inplace=False):
299+
return g.op("Elu", input, alpha_f=_scalar(alpha))
300+
301+
Now PyTorch is able to export ``elu`` operator.
302+
303+
There are more examples in
304+
`symbolic.py <https://github.com/pytorch/pytorch/blob/master/torch/onnx/symbolic.py>`_,
305+
`tensor.py <https://github.com/pytorch/pytorch/blob/99037d627da68cdf53d3d0315deceddfadf03bba/torch/autograd/_functions/tensor.py#L24>`_,
306+
`padding.py <https://github.com/pytorch/pytorch/blob/99037d627da68cdf53d3d0315deceddfadf03bba/torch/nn/_functions/padding.py#L8>`_.
307+
308+
309+
The interface for specifying operator definitions is experimental;
310+
adventurous users should note that the APIs will probably
178311
change in a future interface.
179312

180313
Functions

0 commit comments

Comments
 (0)