|
36 | 36 | #include <torch/csrc/jit/passes/specialize_autogradzero.h> |
37 | 37 | #include <torch/csrc/jit/passes/subgraph_rewrite.h> |
38 | 38 | #include <torch/csrc/jit/passes/utils/check_alias_annotation.h> |
| 39 | +#include <torch/csrc/jit/print_handler.h> |
39 | 40 | #include <torch/csrc/jit/pybind_utils.h> |
40 | 41 | #include <torch/csrc/jit/python_arg_flatten.h> |
41 | 42 | #include <torch/csrc/jit/python_ir.h> |
|
46 | 47 | #include <torch/csrc/jit/script/module.h> |
47 | 48 | #include <torch/csrc/jit/script/python_tree_views.h> |
48 | 49 | #include <torch/csrc/jit/tracer.h> |
| 50 | +#include <torch/csrc/utils/auto_gil.h> |
49 | 51 |
|
50 | 52 | #include <c10/macros/Export.h> |
51 | 53 | #include <caffe2/serialize/inline_container.h> |
52 | 54 |
|
53 | 55 | #include <ATen/core/function_schema.h> |
54 | 56 |
|
55 | 57 | #include <pybind11/functional.h> |
| 58 | +#include <pybind11/iostream.h> |
56 | 59 |
|
57 | 60 | #include <memory> |
58 | 61 | #include <sstream> |
@@ -563,6 +566,16 @@ void initJITBindings(PyObject* module) { |
563 | 566 | tracer::initPythonTracerBindings(module); |
564 | 567 | script::initTreeViewBindings(module); |
565 | 568 | script::initJitScriptBindings(module); |
| 569 | + |
| 570 | + setPrintHandler([](const std::string& str) { |
| 571 | + py::gil_scoped_acquire acquire; |
| 572 | + try { |
| 573 | + auto _stdout = py::module::import("sys").attr("stdout"); |
| 574 | + _stdout.attr("write")(str); |
| 575 | + } catch (py::error_already_set& e) { |
| 576 | + throw std::runtime_error(e.what()); |
| 577 | + } |
| 578 | + }); |
566 | 579 | } |
567 | 580 | } // namespace jit |
568 | 581 | } // namespace torch |
0 commit comments