Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions caffe2/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ if (NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
${TORCH_SRC_DIR}/csrc/jit/passes/utils/check_alias_annotation.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/utils/memory_dag.cpp
${TORCH_SRC_DIR}/csrc/jit/passes/quantization.cpp
${TORCH_SRC_DIR}/csrc/jit/print_handler.cpp
${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp
${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
Expand Down
22 changes: 22 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12929,6 +12929,28 @@ def forward(self, key):
.check_count(s2, 1, exactly=True) \
.check_count("BINGET", 2, exactly=True).run(out.getvalue())

def test_sys_stdout_override(self):
@torch.jit.script
def foo():
print('foo')

class Redirect(object):
def __init__(self):
self.s = ''

def write(self, s):
self.s += s

old_stdout = sys.stdout
redirect = Redirect()
try:
sys.stdout = redirect
foo()
finally:
sys.stdout = old_stdout

FileCheck().check('foo').run(redirect.s)

def test_optional_tuple(self):
def fn(x=None):
# type: (Optional[Tuple[int, int]]) -> Tuple[int, int]
Expand Down
1 change: 1 addition & 0 deletions tools/build_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
"torch/csrc/jit/passes/subgraph_rewrite.cpp",
"torch/csrc/jit/passes/utils/subgraph_utils.cpp",
"torch/csrc/jit/passes/utils/memory_dag.cpp",
"torch/csrc/jit/print_handler.cpp",
"torch/csrc/jit/register_prim_ops.cpp",
"torch/csrc/jit/register_special_ops.cpp",
"torch/csrc/jit/register_quantized_ops.cpp",
Expand Down
13 changes: 13 additions & 0 deletions torch/csrc/jit/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include <torch/csrc/jit/passes/specialize_autogradzero.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
#include <torch/csrc/jit/passes/utils/check_alias_annotation.h>
#include <torch/csrc/jit/print_handler.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/python_arg_flatten.h>
#include <torch/csrc/jit/python_ir.h>
Expand All @@ -46,13 +47,15 @@
#include <torch/csrc/jit/script/module.h>
#include <torch/csrc/jit/script/python_tree_views.h>
#include <torch/csrc/jit/tracer.h>
#include <torch/csrc/utils/auto_gil.h>

#include <c10/macros/Export.h>
#include <caffe2/serialize/inline_container.h>

#include <ATen/core/function_schema.h>

#include <pybind11/functional.h>
#include <pybind11/iostream.h>

#include <memory>
#include <sstream>
Expand Down Expand Up @@ -563,6 +566,16 @@ void initJITBindings(PyObject* module) {
tracer::initPythonTracerBindings(module);
script::initTreeViewBindings(module);
script::initJitScriptBindings(module);

setPrintHandler([](const std::string& str) {
py::gil_scoped_acquire acquire;
try {
auto _stdout = py::module::import("sys").attr("stdout");
_stdout.attr("write")(str);
} catch (py::error_already_set& e) {
throw std::runtime_error(e.what());
}
});
}
} // namespace jit
} // namespace torch
22 changes: 22 additions & 0 deletions torch/csrc/jit/print_handler.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include <torch/csrc/jit/print_handler.h>

#include <iostream>
#include <string>

namespace torch {
namespace jit {

std::atomic<PrintHandler> print_handler([](const std::string& str) {
std::cout << str;
});

PrintHandler getPrintHandler() {
return print_handler.load();
}

void setPrintHandler(PrintHandler ph) {
print_handler.store(ph);
}

} // namespace jit
} // namespace torch
18 changes: 18 additions & 0 deletions torch/csrc/jit/print_handler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#pragma once

#include <torch/csrc/WindowsTorchApiMacro.h>

#include <atomic>
#include <functional>
#include <iostream>

namespace torch {
namespace jit {

using PrintHandler = void (*)(const std::string&);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::function?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nah brah


TORCH_API void setPrintHandler(PrintHandler ph);
TORCH_API PrintHandler getPrintHandler();

} // namespace jit
} // namespace torch
11 changes: 8 additions & 3 deletions torch/csrc/jit/register_prim_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/pickler.h>
#include <torch/csrc/jit/print_handler.h>
#include <torch/csrc/jit/profiling_record.h>
#include <torch/csrc/jit/script/compilation_unit.h>
#include <torch/csrc/jit/script/error_report.h>
Expand Down Expand Up @@ -544,15 +545,19 @@ RegisterOperators reg(
[](const Node* node) {
size_t num_inputs = node->inputs().size();
return [num_inputs](Stack& stack) {
std::stringstream ss;
bool first = true;
for (const IValue& i : last(stack, num_inputs)) {
if (!first)
std::cout << " ";
ss << " ";
first = false;
std::cout << i;
ss << i;
}
drop(stack, num_inputs);
std::cout << std::endl;
ss << std::endl;
auto* handler = getPrintHandler();
TORCH_INTERNAL_ASSERT(handler);
handler(ss.str());
return 0;
};
}),
Expand Down