Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"python.linting.flake8Enabled": true,
"python.linting.mypyEnabled": false,
"python.linting.pydocstyleEnabled": false,
"python.linting.enabled": true
}
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
# reproducible builds (https://github.com/pybind/python_example/pull/53)

ext_modules = [
Pybind11Extension("python_example",
Pybind11Extension(
"python_example",
["src/main.cpp"],
# Example: passing in the version to the compiled code
define_macros = [('VERSION_INFO', __version__)],
),
define_macros=[('VERSION_INFO', __version__)],
),
]

setup(
Expand Down
100 changes: 86 additions & 14 deletions src/main.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,57 @@
#include <pybind11/pybind11.h>

#define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x)
#include <cstdint>
#include <iostream>

struct _SupplementBase {};

struct NcclPreMulSumSupplement : _SupplementBase {};

struct ReduceOp {
enum Kind : uint8_t {
SUM = 0,
PRODUCT,
MIN,
MAX,
BAND, // Bitwise AND
BOR, // Bitwise OR
BXOR, // Bitwise XOR
UNUSED,
};

// ReduceOp() = delete;
ReduceOp() = default;
ReduceOp(Kind op) : op{op} {}

int add(int i, int j) {
return i + j;
static ReduceOp create(Kind op) {
return ReduceOp(op);
}

Kind op = SUM;
};

void func(const ReduceOp a, const ReduceOp b) {
return;
}

void func(const ReduceOp::Kind a, const ReduceOp::Kind b) {
func(ReduceOp(a), ReduceOp(b));
}

struct Options {
Options() {}
Options(ReduceOp reduce_op) : reduce_op{reduce_op} {}
Options(ReduceOp::Kind reduce_op_kind) : reduce_op(reduce_op_kind) {}
ReduceOp reduce_op;
};

void func(Options a, Options b) {
return;
}

#define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x)

namespace py = pybind11;

PYBIND11_MODULE(python_example, m) {
Expand All @@ -19,21 +64,48 @@ PYBIND11_MODULE(python_example, m) {
.. autosummary::
:toctree: _generate

add
subtract
ReduceOp
)pbdoc";

m.def("add", &add, R"pbdoc(
Add two numbers

Some other explanation about the add function.
// Reference: https://pybind11.readthedocs.io/en/stable/classes.html#overloaded-methods
m.def("func", py::overload_cast<const ReduceOp, const ReduceOp>(&func), R"pbdoc(
Takes two ReduceOp
)pbdoc");
m.def("func", py::overload_cast<const ReduceOp::Kind, const ReduceOp::Kind>(&func), R"pbdoc(
Takes two ReduceOp
)pbdoc");
m.def("func", py::overload_cast<Options, Options>(&func), R"pbdoc(Takes two Options)pbdoc");

m.def("subtract", [](int i, int j) { return i - j; }, R"pbdoc(
Subtract two numbers
// m.def("func", static_cast<void ()(const ReduceOp::Kind, const ReduceOp::Kind)>(&func), R"pbdoc(
// Takes two ReduceOp
// )pbdoc");

Some other explanation about the subtract function.
)pbdoc");
py::class_<ReduceOp> reduce_op(m, "ReduceOp");

reduce_op
.def(py::init<ReduceOp::Kind>())
.def_readwrite("op", &ReduceOp::op);

py::enum_<ReduceOp::Kind>(reduce_op, "Kind")
.value("SUM", ReduceOp::SUM)
.value("PRODUCT", ReduceOp::PRODUCT)
.value("MIN", ReduceOp::MIN)
.value("MAX", ReduceOp::MAX)
.value("BAND", ReduceOp::BAND)
.value("BOR", ReduceOp::BOR)
.value("BXOR", ReduceOp::BXOR)
.value("UNUSED", ReduceOp::UNUSED)
.export_values();

py::class_<Options> options(m, "Options");
options
.def(py::init<>())
.def(py::init<ReduceOp>())
.def(py::init<ReduceOp::Kind>())
.def_readwrite("reduce_op", &Options::reduce_op);

// Ref: [Implicit conversions](https://pybind11.readthedocs.io/en/stable/advanced/classes.html#implicit-conversions)
py::implicitly_convertible<ReduceOp::Kind, ReduceOp>();

#ifdef VERSION_INFO
m.attr("__version__") = MACRO_STRINGIFY(VERSION_INFO);
Expand Down
43 changes: 40 additions & 3 deletions tests/test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,42 @@
import unittest

import python_example as m

assert m.__version__ == '0.0.1'
assert m.add(1, 2) == 3
assert m.subtract(1, 2) == -1

class SimpleTest(unittest.TestCase):

def test_func(self):
m.func(m.ReduceOp(m.ReduceOp.SUM), m.ReduceOp(m.ReduceOp.UNUSED))

def test_func_auto_call_ctor(self):
m.func(m.ReduceOp.SUM, m.ReduceOp.UNUSED)

def test_options_op_assignment_from_internal_enum(self):
# ======================================================================
# ERROR: test_options_op_assignment_from_internal_enum (__main__.SimpleTest)
# ----------------------------------------------------------------------
# Traceback (most recent call last):
# File "/home/masaki/ghq/github.com/crcrpar/python_example/tests/test.py", line 16, in test_options_op_assignment_from_internal_enum
# opt_a.reduce_op = m.ReduceOp.SUM
# TypeError: (): incompatible function arguments. The following argument types are supported:
# 1. (self: python_example.Options, arg0: python_example.ReduceOp) -> None

# Invoked with: <python_example.Options object at 0x7f4fe49adcf0>, <Kind.SUM: 0>

# ----------------------------------------------------------------------
# Ran 3 tests in 0.000s
#
# FAILED (errors=1)
# note(crcrpar): Implcicit conversion is what might play a role?
Comment on lines +15 to +30
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ameliorate this comment.

opt_a = m.Options()
opt_a.reduce_op = m.ReduceOp.SUM
m.func(opt_a, opt_a)

def test_options_op_assignment_from_internal_enum_WAR(self):
opt_a = m.Options()
opt_a.reduce_op = m.ReduceOp(m.ReduceOp.SUM)
m.func(opt_a, opt_a)


if __name__ == "__main__":
unittest.main()