Skip to content

Commit c2a18a6

Browse files
James Reedfacebook-github-bot
authored andcommitted
Override print when python is present (#21625)
Summary: This makes it so we can see the output of prim::Print in environments like iPython notebooks which override sys.stdout Pull Request resolved: #21625 Differential Revision: D15756793 Pulled By: jamesr66a fbshipit-source-id: 7d9a14b2e229ed358e784318e9d862677db2c461
1 parent aa7e27f commit c2a18a6

File tree

7 files changed

+85
-3
lines changed

7 files changed

+85
-3
lines changed

caffe2/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
417417
${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
418418
${TORCH_SRC_DIR}/csrc/jit/passes/utils/memory_dag.cpp
419419
${TORCH_SRC_DIR}/csrc/jit/passes/quantization.cpp
420+
${TORCH_SRC_DIR}/csrc/jit/print_handler.cpp
420421
${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp
421422
${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
422423
${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp

test/test_jit.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12983,6 +12983,28 @@ def forward(self, key):
1298312983
.check_count(s2, 1, exactly=True) \
1298412984
.check_count("BINGET", 2, exactly=True).run(out.getvalue())
1298512985

12986+
def test_sys_stdout_override(self):
12987+
@torch.jit.script
12988+
def foo():
12989+
print('foo')
12990+
12991+
class Redirect(object):
12992+
def __init__(self):
12993+
self.s = ''
12994+
12995+
def write(self, s):
12996+
self.s += s
12997+
12998+
old_stdout = sys.stdout
12999+
redirect = Redirect()
13000+
try:
13001+
sys.stdout = redirect
13002+
foo()
13003+
finally:
13004+
sys.stdout = old_stdout
13005+
13006+
FileCheck().check('foo').run(redirect.s)
13007+
1298613008
def test_optional_tuple(self):
1298713009
def fn(x=None):
1298813010
# type: (Optional[Tuple[int, int]]) -> Tuple[int, int]

tools/build_variables.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
"torch/csrc/jit/passes/subgraph_rewrite.cpp",
105105
"torch/csrc/jit/passes/utils/subgraph_utils.cpp",
106106
"torch/csrc/jit/passes/utils/memory_dag.cpp",
107+
"torch/csrc/jit/print_handler.cpp",
107108
"torch/csrc/jit/register_prim_ops.cpp",
108109
"torch/csrc/jit/register_special_ops.cpp",
109110
"torch/csrc/jit/register_quantized_ops.cpp",

torch/csrc/jit/init.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
3737
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
3838
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
39+
#include <torch/csrc/jit/print_handler.h>
3940
#include <torch/csrc/jit/pybind_utils.h>
4041
#include <torch/csrc/jit/python_arg_flatten.h>
4142
#include <torch/csrc/jit/python_ir.h>
@@ -46,13 +47,15 @@
4647
#include <torch/csrc/jit/script/module.h>
4748
#include <torch/csrc/jit/script/python_tree_views.h>
4849
#include <torch/csrc/jit/tracer.h>
50+
#include <torch/csrc/utils/auto_gil.h>
4951

5052
#include <c10/macros/Export.h>
5153
#include <caffe2/serialize/inline_container.h>
5254

5355
#include <ATen/core/function_schema.h>
5456

5557
#include <pybind11/functional.h>
58+
#include <pybind11/iostream.h>
5659

5760
#include <memory>
5861
#include <sstream>
@@ -563,6 +566,16 @@ void initJITBindings(PyObject* module) {
563566
tracer::initPythonTracerBindings(module);
564567
script::initTreeViewBindings(module);
565568
script::initJitScriptBindings(module);
569+
570+
setPrintHandler([](const std::string& str) {
571+
py::gil_scoped_acquire acquire;
572+
try {
573+
auto _stdout = py::module::import("sys").attr("stdout");
574+
_stdout.attr("write")(str);
575+
} catch (py::error_already_set& e) {
576+
throw std::runtime_error(e.what());
577+
}
578+
});
566579
}
567580
} // namespace jit
568581
} // namespace torch

torch/csrc/jit/print_handler.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#include <torch/csrc/jit/print_handler.h>
2+
3+
#include <iostream>
4+
#include <string>
5+
6+
namespace torch {
7+
namespace jit {
8+
9+
std::atomic<PrintHandler> print_handler([](const std::string& str) {
10+
std::cout << str;
11+
});
12+
13+
PrintHandler getPrintHandler() {
14+
return print_handler.load();
15+
}
16+
17+
void setPrintHandler(PrintHandler ph) {
18+
print_handler.store(ph);
19+
}
20+
21+
} // namespace jit
22+
} // namespace torch

torch/csrc/jit/print_handler.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#pragma once
2+
3+
#include <torch/csrc/WindowsTorchApiMacro.h>
4+
5+
#include <atomic>
6+
#include <functional>
7+
#include <iostream>
8+
9+
namespace torch {
10+
namespace jit {
11+
12+
using PrintHandler = void (*)(const std::string&);
13+
14+
TORCH_API void setPrintHandler(PrintHandler ph);
15+
TORCH_API PrintHandler getPrintHandler();
16+
17+
} // namespace jit
18+
} // namespace torch

torch/csrc/jit/register_prim_ops.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <torch/csrc/jit/ir.h>
1111
#include <torch/csrc/jit/operator.h>
1212
#include <torch/csrc/jit/pickler.h>
13+
#include <torch/csrc/jit/print_handler.h>
1314
#include <torch/csrc/jit/profiling_record.h>
1415
#include <torch/csrc/jit/script/compilation_unit.h>
1516
#include <torch/csrc/jit/script/error_report.h>
@@ -552,15 +553,19 @@ RegisterOperators reg(
552553
[](const Node* node) {
553554
size_t num_inputs = node->inputs().size();
554555
return [num_inputs](Stack& stack) {
556+
std::stringstream ss;
555557
bool first = true;
556558
for (const IValue& i : last(stack, num_inputs)) {
557559
if (!first)
558-
std::cout << " ";
560+
ss << " ";
559561
first = false;
560-
std::cout << i;
562+
ss << i;
561563
}
562564
drop(stack, num_inputs);
563-
std::cout << std::endl;
565+
ss << std::endl;
566+
auto* handler = getPrintHandler();
567+
TORCH_INTERNAL_ASSERT(handler);
568+
handler(ss.str());
564569
return 0;
565570
};
566571
}),

0 commit comments

Comments
 (0)