Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions test/jit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,18 @@ def disableEmitHook(self):
yield None
self.setHooks()

def _isHookExceptionOk(self, e):
se = str(e)
allowed = ("Could not export Python function",
"closures are not exportable")
for a in allowed:
if a in se:
return True
return False

def emitFunctionHook(self, func):
# func has invalid names for export, skip the jitter check
if func.name == "<lambda>" or "aten::" in func.name:
if func.name == "<lambda>" or "aten::" in func.name or _in_first_class_mode:
return
# disable the hook while we parse code, otherwise we will re-enter the hook
with self.disableEmitHook():
Expand All @@ -72,9 +81,7 @@ def emitFunctionHook(self, func):
src2, constants2 = _jit_python_print(func2)
self.assertMultiLineEqual(src, src2)
except RuntimeError as e:
se = str(e)
if "Could not export Python function" not in se and \
"closures are not exportable" not in se:
if not self._isHookExceptionOk(e):
raise

def emitModuleHook(self, module):
Expand Down Expand Up @@ -113,9 +120,7 @@ def copy_structure_and_params(m):
for line in main_module:
main_module_code += line.decode()
except RuntimeError as e:
se = str(e)
if "Could not export Python function" not in se and \
"closures are not exportable" not in se:
if not self._isHookExceptionOk(e):
raise
else:
return
Expand Down Expand Up @@ -428,6 +433,18 @@ def enable_profiling_mode():
yield
torch._C._jit_set_profiling_mode(False)


_in_first_class_mode = False
@contextmanager
def enable_first_class_mode():
global _in_first_class_mode
torch._C._jit_set_first_class_mode(True)
_in_first_class_mode = True
yield
torch._C._jit_set_first_class_mode(False)
_in_first_class_mode = False


# note: not re-entrant, use unnested only
@contextmanager
def disable_autodiff_subgraph_inlining(enabled=True):
Expand Down
28 changes: 22 additions & 6 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
freeze_rng_state, set_rng_seed, slowTest, TemporaryFileName
from jit_utils import JitTestCase, enable_cpu_fuser, disable_autodiff_subgraph_inlining, \
_trace, enable_cpu_fuser_if, enable_profiling_mode
_trace, enable_cpu_fuser_if, enable_profiling_mode, enable_first_class_mode
from common_nn import module_tests, new_module_tests, criterion_tests
from common_methods_invocations import method_tests as autograd_method_tests
from common_methods_invocations import create_input, unpack_variables, \
Expand Down Expand Up @@ -230,11 +230,6 @@ def _sum_of_list(tensorlist):
s += t.sum()
return s

@contextmanager
def enable_first_class_mode():
torch._C._jit_set_first_class_mode(True)
yield
torch._C._jit_set_first_class_mode(False)

# helper function to generate test qparam
def _helper_generate_qparam(script_module, input_data):
Expand Down Expand Up @@ -2991,6 +2986,27 @@ def forward(self, input):
foo.forward(input)
self.assertEqual(input, foo.foo)

def test_first_class_calls(self):
with enable_first_class_mode():
@torch.jit.script
class Foo(object):
def __init__(self, x):
self.bar = x

def stuff(self, x):
return self.bar + x

@torch.jit.script
def foo(x):
return x * x + Foo(x).stuff(2 * x)

@torch.jit.script
def bar(x):
return foo(x) * foo(x)

x = torch.rand(3, 4)
self.assertEqual(bar(x), (x * x + 3 * x) * (x * x + 3 * x))

def test_invalid_prefix_annotation(self):
with self.assertRaisesRegex(RuntimeError, "annotation prefix in line"):
with self.capture_stdout() as captured:
Expand Down
1 change: 0 additions & 1 deletion torch/csrc/jit/argument_spec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,6 @@ ArgumentSpec ArgumentSpecCreator::create(bool with_grad, const Stack& input)
// consume object
const IValue* iv = stack[stack_top]++;
AT_ASSERT(iv->isObject());
iv->toObject();
// see [argspec refcounting]
auto p = *reinterpret_cast<const at::ivalue::Object* const*>(iv);
auto obj_ptr = &p->slots()[0];
Expand Down
31 changes: 31 additions & 0 deletions torch/csrc/jit/exception_message.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#pragma once
#include <c10/util/Exception.h>
#include <stdexcept>

namespace torch {
namespace jit {

struct ExceptionMessage {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment explaining what purpose this serves? Why cut out the c10 backtarce is not obvious I think.

ExceptionMessage(const std::exception& e) : e_(e) {}

private:
const std::exception& e_;
friend std::ostream& operator<<(
std::ostream& out,
const ExceptionMessage& msg);
};

inline std::ostream& operator<<(
std::ostream& out,
const ExceptionMessage& msg) {
auto c10_error = dynamic_cast<const c10::Error*>(&msg.e_);
if (c10_error) {
out << c10_error->msg_without_backtrace();
} else {
out << msg.e_.what();
}
return out;
}

} // namespace jit
} // namespace torch
4 changes: 4 additions & 0 deletions torch/csrc/jit/graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,10 @@ void GraphExecutor::run(Stack& inputs) {
return pImpl->run(inputs);
}

ExecutionPlan GraphExecutor::getPlanFor(Stack& inputs) {
return pImpl->getPlanFor(inputs);
}

std::shared_ptr<Graph> GraphExecutor::graph() const {
return pImpl->graph;
}
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/jit/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ struct TORCH_API GraphExecutor {
GraphExecutor() = default;
GraphExecutor(std::shared_ptr<Graph> graph, bool optimize = true);
void run(Stack& inputs);
ExecutionPlan getPlanFor(Stack& inputs);
explicit operator bool() const {
return pImpl != nullptr;
}
Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/jit/init.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ void initJITBindings(PyObject* module) {
[](bool profiling_flag) { getProfilingMode() = profiling_flag; })
.def(
"_jit_set_first_class_mode",
[](bool enabled) { script::setRunAsFirstClass(enabled); })
[](bool enabled) { script::getFirstClassMode() = enabled; })
.def(
"_jit_fuser_get_fused_kernel_code",
[](Graph& g, std::vector<at::Tensor> inps) {
Expand Down
Loading