Skip to content

Commit d089ea3

Browse files
author
James Reed
committed
switch
1 parent 79b13d9 commit d089ea3

File tree

4 files changed

+20
-5
lines changed

4 files changed

+20
-5
lines changed

test/test_jit.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12929,6 +12929,22 @@ def forward(self, key):
1292912929
.check_count(s2, 1, exactly=True) \
1293012930
.check_count("BINGET", 2, exactly=True).run(out.getvalue())
1293112931

12932+
def test_sys_stdout_override(self):
12933+
@torch.jit.script
12934+
def foo():
12935+
print('foo')
12936+
12937+
old_stdout = sys.stdout
12938+
with tempfile.TemporaryFile() as f:
12939+
try:
12940+
sys.stdout = f
12941+
foo()
12942+
finally:
12943+
sys.stdout = old_stdout
12944+
12945+
f.seek(0)
12946+
FileCheck().check('foo').run(f.read())
12947+
1293212948
def test_optional_tuple(self):
1293312949
def fn(x=None):
1293412950
# type: (Optional[Tuple[int, int]]) -> Tuple[int, int]

torch/csrc/jit/init.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@
5454

5555
#include <ATen/core/function_schema.h>
5656

57-
#include <Python.h>
5857
#include <pybind11/functional.h>
58+
#include <pybind11/iostream.h>
5959

6060
#include <memory>
6161
#include <sstream>
@@ -570,7 +570,7 @@ void initJITBindings(PyObject* module) {
570570
setPrintHandler([](const std::string& str) {
571571
py::gil_scoped_acquire acquire;
572572
auto _stdout = py::module::import("sys").attr("stdout");
573-
_stdout.attr("write")(str);
573+
_stdout.attr("write")(py::bytes(str));
574574
});
575575
}
576576
} // namespace jit

torch/csrc/jit/init.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
#pragma once
22

3-
#include <functional>
4-
#include <iostream>
5-
63
namespace torch {
74
namespace jit {
85

torch/csrc/jit/print_handler.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include <torch/csrc/jit/print_handler.h>
22

3+
#include <iostream>
4+
35
namespace torch {
46
namespace jit {
57

0 commit comments

Comments
 (0)