|
| 1 | +#include "torch/csrc/onnx/init.h" |
| 2 | +#include "torch/csrc/onnx/onnx.pb.h" |
| 3 | + |
| 4 | +namespace torch { namespace onnx { |
| 5 | +void initONNXBindings(PyObject* module) { |
| 6 | + auto m = py::handle(module).cast<py::module>(); |
| 7 | + auto onnx = m.def_submodule("_onnx"); |
| 8 | + py::enum_<onnx_TensorProto_DataType>(onnx, "TensorProtoDataType") |
| 9 | + .value("UNDEFINED", onnx_TensorProto_DataType_UNDEFINED) |
| 10 | + .value("FLOAT", onnx_TensorProto_DataType_FLOAT) |
| 11 | + .value("UINT8", onnx_TensorProto_DataType_UINT8) |
| 12 | + .value("INT8", onnx_TensorProto_DataType_INT8) |
| 13 | + .value("UINT16", onnx_TensorProto_DataType_UINT16) |
| 14 | + .value("INT16", onnx_TensorProto_DataType_INT16) |
| 15 | + .value("INT32", onnx_TensorProto_DataType_INT32) |
| 16 | + .value("INT64", onnx_TensorProto_DataType_INT64) |
| 17 | + .value("STRING", onnx_TensorProto_DataType_STRING) |
| 18 | + .value("BOOL", onnx_TensorProto_DataType_BOOL) |
| 19 | + .value("FLOAT16", onnx_TensorProto_DataType_FLOAT16) |
| 20 | + .value("DOUBLE", onnx_TensorProto_DataType_DOUBLE) |
| 21 | + .value("UINT32", onnx_TensorProto_DataType_UINT32) |
| 22 | + .value("UINT64", onnx_TensorProto_DataType_UINT64) |
| 23 | + .value("COMPLEX64", onnx_TensorProto_DataType_COMPLEX64) |
| 24 | + .value("COMPLEX128", onnx_TensorProto_DataType_COMPLEX128); |
| 25 | +} |
| 26 | +}} // namespace torch::onnx |
0 commit comments