Skip to content

Commit c43c911

Browse files
authored
Export onnx protobuf bindings to python (#6651)
* Export onnx protobuf bindings to python * rename native onnx module to _onnx
1 parent f50f176 commit c43c911

File tree

6 files changed

+50
-8
lines changed

6 files changed

+50
-8
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,7 @@ def run(self):
686686
"torch/csrc/tensor/python_tensor.cpp",
687687
"torch/csrc/onnx/onnx.pb.cpp",
688688
"torch/csrc/onnx/onnx.cpp",
689+
"torch/csrc/onnx/init.cpp",
689690
]
690691

691692
try:

torch/csrc/Module.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "torch/csrc/jit/python_tracer.h"
3535
#include "torch/csrc/jit/init.h"
3636
#include "torch/csrc/jit/python_ir.h"
37+
#include "torch/csrc/onnx/init.h"
3738

3839
#ifdef WITH_CUDNN
3940
#include "cudnn.h"
@@ -478,6 +479,7 @@ static PyObject* initModule() {
478479
ASSERT_TRUE(THPEngine_initModule(module));
479480
torch::autograd::initAutogradClosureBindings(module);
480481
torch::jit::initJITBindings(module);
482+
torch::onnx::initONNXBindings(module);
481483
torch::autograd::initNNFunctions(module);
482484
torch::autograd::init_legacy_variable(module);
483485
#ifdef WITH_CUDA

torch/csrc/onnx/init.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#include "torch/csrc/onnx/init.h"
2+
#include "torch/csrc/onnx/onnx.pb.h"
3+
4+
namespace torch { namespace onnx {
5+
void initONNXBindings(PyObject* module) {
6+
auto m = py::handle(module).cast<py::module>();
7+
auto onnx = m.def_submodule("_onnx");
8+
py::enum_<onnx_TensorProto_DataType>(onnx, "TensorProtoDataType")
9+
.value("UNDEFINED", onnx_TensorProto_DataType_UNDEFINED)
10+
.value("FLOAT", onnx_TensorProto_DataType_FLOAT)
11+
.value("UINT8", onnx_TensorProto_DataType_UINT8)
12+
.value("INT8", onnx_TensorProto_DataType_INT8)
13+
.value("UINT16", onnx_TensorProto_DataType_UINT16)
14+
.value("INT16", onnx_TensorProto_DataType_INT16)
15+
.value("INT32", onnx_TensorProto_DataType_INT32)
16+
.value("INT64", onnx_TensorProto_DataType_INT64)
17+
.value("STRING", onnx_TensorProto_DataType_STRING)
18+
.value("BOOL", onnx_TensorProto_DataType_BOOL)
19+
.value("FLOAT16", onnx_TensorProto_DataType_FLOAT16)
20+
.value("DOUBLE", onnx_TensorProto_DataType_DOUBLE)
21+
.value("UINT32", onnx_TensorProto_DataType_UINT32)
22+
.value("UINT64", onnx_TensorProto_DataType_UINT64)
23+
.value("COMPLEX64", onnx_TensorProto_DataType_COMPLEX64)
24+
.value("COMPLEX128", onnx_TensorProto_DataType_COMPLEX128);
25+
}
26+
}} // namespace torch::onnx

torch/csrc/onnx/init.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#pragma once
2+
3+
#include "torch/csrc/utils/pybind.h"
4+
5+
namespace torch { namespace onnx {
6+
7+
void initONNXBindings(PyObject* module);
8+
9+
}} // namespace torch::onnx

torch/onnx/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import functools
22
import types
33

4+
import torch._C as _C
5+
6+
TensorProtoDataType = _C._onnx.TensorProtoDataType
7+
48
ONNX_ARCHIVE_MODEL_PROTO_NAME = "__MODEL_PROTO"
59

610

torch/onnx/symbolic.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -631,14 +631,14 @@ def _unique(g, input, sorted, return_inverse):
631631
# TODO: remove these once we support Type's in the JIT IR and we can once again
632632
# use the unified toType operator
633633
cast_pytorch_to_onnx = {
634-
'uint8_t': 2,
635-
'int8_t': 3,
636-
'double': 11,
637-
'float': 1,
638-
'Half': 10,
639-
'int': 6,
640-
'int64_t': 7,
641-
'int16_t': 5,
634+
'uint8_t': torch.onnx.TensorProtoDataType.UINT8,
635+
'int8_t': torch.onnx.TensorProtoDataType.INT8,
636+
'double': torch.onnx.TensorProtoDataType.DOUBLE,
637+
'float': torch.onnx.TensorProtoDataType.FLOAT,
638+
'Half': torch.onnx.TensorProtoDataType.FLOAT16,
639+
'int': torch.onnx.TensorProtoDataType.INT32,
640+
'int64_t': torch.onnx.TensorProtoDataType.INT64,
641+
'int16_t': torch.onnx.TensorProtoDataType.INT16,
642642
}
643643

644644

0 commit comments

Comments
 (0)