forked from pybind/python_example
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.cpp
More file actions
115 lines (90 loc) · 2.84 KB
/
main.cpp
File metadata and controls
115 lines (90 loc) · 2.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#include <pybind11/pybind11.h>
#include <cstdint>
#include <iostream>
// Empty marker base type for reduce-op "supplement" payloads; it declares no
// members. Nothing else in this file references it.
// NOTE(review): names beginning with an underscore followed by an uppercase
// letter are reserved for the implementation in C++; consider renaming.
struct _SupplementBase {};

// Empty marker derived (publicly) from _SupplementBase. The name suggests the
// payload for an NCCL pre-multiply-sum reduction — TODO confirm; no members yet.
struct NcclPreMulSumSupplement : public _SupplementBase {};
// Reduction-operator descriptor: wraps a single Kind tag.
//
// The Kind constructor is intentionally non-explicit, so a bare Kind converts
// implicitly to a ReduceOp wherever one is expected.
struct ReduceOp {
    // Reduction kinds, stored in one byte. Values are consecutive from 0.
    enum Kind : uint8_t {
        SUM = 0,
        PRODUCT = 1,
        MIN = 2,
        MAX = 3,
        BAND = 4,    // Bitwise AND
        BOR = 5,     // Bitwise OR
        BXOR = 6,    // Bitwise XOR
        UNUSED = 7,
    };

    // Default construction yields SUM via the member initializer on `op`.
    ReduceOp() = default;

    // Wraps `kind` (converting constructor; see note above).
    ReduceOp(Kind kind) : op(kind) {}

    // Named factory; equivalent to ReduceOp(kind).
    static ReduceOp create(Kind kind) {
        ReduceOp result(kind);
        return result;
    }

    // The selected reduction kind.
    Kind op{SUM};
};
// Overload taking two ReduceOp values by value. It is a no-op: both
// arguments are ignored (parameter names left unnamed to make that explicit).
void func(const ReduceOp /*a*/, const ReduceOp /*b*/) {}
// Overload taking two bare Kind tags: wraps each one in a ReduceOp and
// forwards to func(ReduceOp, ReduceOp).
void func(const ReduceOp::Kind a, const ReduceOp::Kind b) {
    const ReduceOp lhs{a};
    const ReduceOp rhs{b};
    func(lhs, rhs);
}
struct Options {
Options() {}
Options(ReduceOp reduce_op) : reduce_op{reduce_op} {}
Options(ReduceOp::Kind reduce_op_kind) : reduce_op(reduce_op_kind) {}
ReduceOp reduce_op;
};
// Overload taking two Options by value. It is a no-op: both arguments are
// ignored (parameter names left unnamed to make that explicit).
void func(Options /*a*/, Options /*b*/) {}
// Two-level stringification so that VERSION_INFO is macro-expanded before it
// is turned into a string literal.
#define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x)

namespace py = pybind11;

// Python bindings for ReduceOp, ReduceOp.Kind, Options and the overloaded
// `func`.
//
// Ordering matters: pybind11 generates a function's signature/docstring at the
// moment `.def(...)` is called. Any parameter type that is not registered yet
// at that point is rendered under its C++ type name instead of its Python
// name. So every type is registered first, and the functions that use those
// types are bound afterwards.
PYBIND11_MODULE(python_example, m) {
    m.doc() = R"pbdoc(
        Pybind11 example plugin
        -----------------------

        .. currentmodule:: python_example

        .. autosummary::
           :toctree: _generate

           ReduceOp
           Options
           func
    )pbdoc";

    // --- ReduceOp and its nested Kind enum ----------------------------------
    // The class object must exist before Kind can be nested in its scope, and
    // Kind must be registered before the ReduceOp(Kind) constructor is bound.
    py::class_<ReduceOp> reduce_op(m, "ReduceOp");

    py::enum_<ReduceOp::Kind>(reduce_op, "Kind")
        .value("SUM", ReduceOp::SUM)
        .value("PRODUCT", ReduceOp::PRODUCT)
        .value("MIN", ReduceOp::MIN)
        .value("MAX", ReduceOp::MAX)
        .value("BAND", ReduceOp::BAND)
        .value("BOR", ReduceOp::BOR)
        .value("BXOR", ReduceOp::BXOR)
        .value("UNUSED", ReduceOp::UNUSED)
        .export_values();

    reduce_op
        .def(py::init<ReduceOp::Kind>())
        .def_readwrite("op", &ReduceOp::op);

    // Allow a Kind wherever a ReduceOp is expected, mirroring the implicit
    // C++ constructor ReduceOp(Kind). Requires ReduceOp to be registered.
    // Ref: https://pybind11.readthedocs.io/en/stable/advanced/classes.html#implicit-conversions
    py::implicitly_convertible<ReduceOp::Kind, ReduceOp>();

    // --- Options ------------------------------------------------------------
    py::class_<Options> options(m, "Options");
    options
        .def(py::init<>())
        .def(py::init<ReduceOp>())
        .def(py::init<ReduceOp::Kind>())
        .def_readwrite("reduce_op", &Options::reduce_op);

    // --- func overloads -----------------------------------------------------
    // pybind11 tries overloads in registration order (first without implicit
    // conversions, then with them), so this order is part of the behavior.
    // Ref: https://pybind11.readthedocs.io/en/stable/classes.html#overloaded-methods
    m.def("func", py::overload_cast<const ReduceOp, const ReduceOp>(&func), R"pbdoc(
        Takes two ReduceOp
    )pbdoc");
    m.def("func", py::overload_cast<const ReduceOp::Kind, const ReduceOp::Kind>(&func), R"pbdoc(
        Takes two ReduceOp.Kind
    )pbdoc");
    m.def("func", py::overload_cast<Options, Options>(&func), R"pbdoc(Takes two Options)pbdoc");
    // C++11 alternative to py::overload_cast (which needs C++14):
    //   m.def("func", static_cast<void (*)(ReduceOp::Kind, ReduceOp::Kind)>(&func), ...);

#ifdef VERSION_INFO
    m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);
#else
    m.attr("__version__") = "dev";
#endif
}