Changes from all commits
2 changes: 1 addition & 1 deletion test/test_jit.py
@@ -16185,7 +16185,7 @@ def _xor(): # noqa: E306
_or, _xor]:
self.checkScript(func, ())

-with self.assertRaisesRegex(RuntimeError, "because it does not define a __add__"):
+with self.assertRaisesRegex(RuntimeError, "__add__ method"):
@torch.jit.script
def test():
return Foo(torch.tensor(1)) + Foo(torch.tensor(1))
35 changes: 29 additions & 6 deletions torch/csrc/jit/ir.cpp
@@ -201,11 +201,11 @@ void Node::printAttributes(std::ostream& out, bool ignore_subgraph = false)
}

SourceRange Node::sourceRange() const {
-  if(source_range_) {
-    return *source_range_;
-  }
-  std::stringstream ss;
-  return SourceRange(ss.str());
+  if (source_range_) {
+    return *source_range_;
+  }
+  std::stringstream ss;
+  return SourceRange(ss.str());
}

static std::ostream& indent(std::ostream& out, size_t level) {
@@ -236,7 +236,6 @@ std::ostream& Node::print(

groups->push_back(this);
  } else {
-
    out << kind().toQualString();
if (hasAttributes()) {
printAttributes(out);
@@ -1429,6 +1428,30 @@ Node* Graph::createGetAttr(Value* obj, const std::string& field) {
return n;
}

+Value* Graph::insertFunctionCall(
+    std::shared_ptr<script::Function> callee,
+    script::MatchedSchema& matched) {
+  Value* fn_constant = insertNode(create(prim::Constant))
+                           ->output()
+                           ->setType(FunctionType::create(std::move(callee)));
+  std::vector<Value*> inputs = {fn_constant};
+  inputs.insert(inputs.end(), matched.inputs.begin(), matched.inputs.end());
+  Value* result = insertNode(create(prim::CallFunction, inputs))
+                      ->output()
+                      ->setType(matched.return_types.at(0));
+  return result;
+}
+
+Value* Graph::insertMethodCall(
+    std::string method_name,
+    script::MatchedSchema& matched) {
+  Value* result = insertNode(create(prim::CallMethod, matched.inputs))
+                      ->s_(attr::name, std::move(method_name))
+                      ->output()
+                      ->setType(matched.return_types.at(0));
+  return result;
+}
+
Node* Graph::createClone(
Node* n,
const std::function<Value*(Value*)>& value_map,
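The two helpers above centralize the call-emission boilerplate that previously lived in script::Function::try_emit_call (see module.cpp below). A minimal caller-side sketch, assuming `matched` came from a successful script::tryMatchSchema; the wrapper name emitMatchedCall is hypothetical:

#include <torch/csrc/jit/ir.h>

using namespace torch::jit;

// Hypothetical wrapper: emit a call once the schema has been matched.
// insertFunctionCall materializes the callee as a prim::Constant of
// FunctionType and prepends it to the inputs of a prim::CallFunction
// node; insertMethodCall emits prim::CallMethod with the object as
// inputs[0] and the method name stored as the attr::name attribute.
Value* emitMatchedCall(
    Graph& graph,
    std::shared_ptr<script::Function> callee, // null when calling a method
    std::string method_name,                  // unused for function calls
    script::MatchedSchema& matched) {
  if (callee) {
    return graph.insertFunctionCall(std::move(callee), matched);
  }
  return graph.insertMethodCall(std::move(method_name), matched);
}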
12 changes: 12 additions & 0 deletions torch/csrc/jit/ir.h
@@ -79,6 +79,11 @@ namespace aten {
using namespace ::c10::aten;
}

+namespace script {
+struct Function;
+struct MatchedSchema;
+} // namespace script
+
// Graph represents one "function" of computation.
// It uses a simple ownership model where the graph owns all the nodes inside
// it. All references inside the graph are raw pointers. Destroying the Graph
@@ -1093,6 +1098,13 @@ struct Graph {
return insertNode(createGetAttr(obj, field))->output();
}

+  TORCH_API Value* insertFunctionCall(
+      std::shared_ptr<script::Function> callee,
+      script::MatchedSchema& matched);
+  TORCH_API Value* insertMethodCall(
+      std::string method_name,
+      script::MatchedSchema& matched);
+
// Note: defined in python_ir.cpp and can be used only in python extension
Node* createPythonOp(
THPObjectPtr&& pyobj,
73 changes: 40 additions & 33 deletions torch/csrc/jit/passes/inliner.cpp
@@ -9,41 +9,48 @@ namespace prim {
using namespace ::c10::prim;
}

-void inlineCalls(Block* block) {
-  Node* cur = block->nodes().front();
-  Node* end = block->return_node();
-
-  while (cur != end) {
-    auto next = cur->next();
-    for (auto b : cur->blocks()) {
-      inlineCalls(b);
-    }
-    if (cur->kind() == prim::CallFunction) {
-      AT_ASSERT(cur->inputs().at(0)->node()->kind() == prim::Constant);
-      auto function_constant = cur->inputs().at(0)->node();
-      auto fun_type =
-          function_constant->output()->type()->expect<FunctionType>();
-      auto graph = fun_type->function()->graph();
-
-      auto old_output = cur->outputs();
-      // slice function ptr value
-      auto inputs = cur->inputs().slice(1);
-      WithInsertPoint guard(next);
-      auto new_output =
-          inlineCallTo(*cur->owningGraph(), *graph.get(), inputs).at(0);
-      if (old_output.at(0)->hasUniqueName()) {
-        auto name = old_output.at(0)->uniqueName();
-        new_output->setUniqueName(name);
-      }
+static void replace(
+    Node* to_replace,
+    const std::shared_ptr<script::Function>& fn,
+    at::ArrayRef<Value*> inputs) {
+  WithInsertPoint guard(to_replace);
+  auto new_output =
+      inlineCallTo(*to_replace->owningGraph(), *fn->graph(), inputs).at(0);
+  if (to_replace->output()->hasUniqueName()) {
+    new_output->setUniqueName(to_replace->output()->uniqueName());
+  }
+  to_replace->output()->replaceAllUsesWith(new_output);
+}

-      old_output.at(0)->replaceAllUsesWith(new_output);
-      next = cur->next();
-      cur->destroy();
-      if (!function_constant->hasUses()) {
-        function_constant->destroy();
-      }
+void inlineCalls(Block* block) {
+  for (auto it = block->nodes().begin(), end = block->nodes().end();
+       it != end;) {
+    Node* cur = *it++;
+    switch (cur->kind()) {
+      case prim::CallFunction: {
+        AT_ASSERT(cur->inputs().at(0)->node()->kind() == prim::Constant);
+        auto function_constant = cur->inputs().at(0)->node();
+        auto fun_type =
+            function_constant->output()->type()->expect<FunctionType>();
+        replace(cur, fun_type->function(), cur->inputs().slice(1));
+        cur->destroy();
+        if (!function_constant->hasUses()) {
+          function_constant->destroy();
+        }
+      } break;
+      case prim::CallMethod: {
+        const std::string& name = cur->s(attr::name);
+        auto function =
+            cur->inputs().at(0)->type()->expect<ClassType>()->getMethod(name);
+        replace(cur, function, cur->inputs());
+        cur->destroy();
+      } break;
+      default: {
+        for (auto b : cur->blocks()) {
+          inlineCalls(b);
+        }
+      } break;
+    }
-    cur = next;
  }
}

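The pass entry point that drives inlineCalls is not part of this diff; a sketch of the assumed driver (its declaration is assumed to live in passes/inliner.h):

// Assumed driver: walk the graph's top-level block. inlineCalls recurses
// into sub-blocks itself via the default case of its switch.
void Inline(Graph& graph) {
  inlineCalls(graph.block());
}

Note the loop shape in the new inlineCalls: `Node* cur = *it++;` advances the iterator before any cur->destroy() can run, so destroying an inlined call node never invalidates the iterator. That is what allows dropping the explicit `next` bookkeeping of the old while loop.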
15 changes: 9 additions & 6 deletions torch/csrc/jit/script/builtin_functions.cpp
@@ -1,6 +1,6 @@
-#include <torch/csrc/jit/script/builtin_functions.h>
#include <torch/csrc/api/include/torch/jit.h>
#include <torch/csrc/jit/code_template.h>
+#include <torch/csrc/jit/script/builtin_functions.h>
#include <torch/csrc/jit/script/resolver.h>

namespace torch {
@@ -38,8 +38,9 @@ def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]:
)SCRIPT");

struct BuiltinFunctionRegistry {
-  const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
-    const static std::vector<Function*> empty;
+  const std::vector<std::shared_ptr<Function>>& getAllBuiltinFunctionsFor(
+      Symbol name) {
+    const static std::vector<std::shared_ptr<Function>> empty;
// when initializing the builtin function library, we will re-enter
// getAllBuiltinFunctionsFor since it is called in the compiler to
// lookup builtins and initializing the builtin functions calls the
@@ -68,7 +69,7 @@ struct BuiltinFunctionRegistry {
cu->define(source, script::nativeResolver(), /*self=*/nullptr);
for (auto& method : cu->get_functions()) {
builtins_by_name[Symbol::fromQualString("aten::" + method->name())]
-            .push_back(method.get());
+            .push_back(method);
}
}
void loadBuiltinFunctions() {
@@ -98,10 +99,12 @@ struct BuiltinFunctionRegistry {
enum { UNINITIALIZED, INTIIALIZING, INITIALIZED } state = UNINITIALIZED;
std::recursive_mutex mutex;
std::vector<std::shared_ptr<CompilationUnit>> modules;
-  std::unordered_map<Symbol, std::vector<Function*>> builtins_by_name;
+  std::unordered_map<Symbol, std::vector<std::shared_ptr<Function>>>
+      builtins_by_name;
};

-const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
+const std::vector<std::shared_ptr<Function>>& getAllBuiltinFunctionsFor(
+    Symbol name) {
static BuiltinFunctionRegistry registry;
return registry.getAllBuiltinFunctionsFor(name);
}
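With the registry vending owning pointers, the result of a lookup can outlive the registry's compilation units. A hedged usage sketch; the symbol is illustrative, since builtins are registered under "aten::" plus the generated function's name:

// Illustrative lookup, assumed to run inside torch::jit::script.
const std::vector<std::shared_ptr<Function>>& overloads =
    getAllBuiltinFunctionsFor(Symbol::fromQualString("aten::_single"));
if (!overloads.empty()) {
  // Copying the shared_ptr extends the callee's lifetime beyond the
  // registry, e.g. when call emission captures it in a FunctionType.
  std::shared_ptr<Function> fn = overloads.front();
}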
4 changes: 2 additions & 2 deletions torch/csrc/jit/script/builtin_functions.h
@@ -7,8 +7,8 @@ namespace torch {
namespace jit {
namespace script {

-TORCH_API const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name);
-
+TORCH_API const std::vector<std::shared_ptr<Function>>&
+getAllBuiltinFunctionsFor(Symbol name);
}
} // namespace jit
} // namespace torch

Review thread on this declaration:

Member: Why did we change these to owning pointers?

Contributor Author: Because the FunctionType object captures them as such when an emitCall happens.
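To make the rationale in the thread concrete: Graph::insertFunctionCall (ir.cpp above) stores the callee inside a FunctionType attached to a prim::Constant, so the graph itself co-owns the Function. A sketch of that capture, mirroring the ir.cpp hunk; the wrapper name is hypothetical:

// The graph keeps `callee` alive through the FunctionType. A raw
// Function* handed out by the registry could not provide that guarantee.
Value* makeFunctionConstant(
    Graph& graph,
    std::shared_ptr<script::Function> callee) {
  return graph.insertNode(graph.create(prim::Constant))
      ->output()
      ->setType(FunctionType::create(std::move(callee)));
}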
21 changes: 1 addition & 20 deletions torch/csrc/jit/script/compilation_unit.h
@@ -121,25 +121,6 @@ struct TORCH_API Function : public std::enable_shared_from_this<Function> {
return executor_;
}

-  // returns nullptr and fills in failure_messages if the callee does not
-  // match the functions schema
-
-  // TODO: defined in module.cpp, move to compilation_unit.cpp
-  Value* try_emit_call(
-      Graph& graph,
-      const SourceRange& loc,
-      c10::optional<NamedValue> self,
-      ArrayRef<NamedValue> args,
-      ArrayRef<NamedValue> kwargs,
-      std::ostream* failure_messages,
-      bool conv_tensors_to_nums);
-
-  Value* emit_call(
-      Graph& graph,
-      const SourceRange& loc,
-      ArrayRef<NamedValue> args,
-      ArrayRef<NamedValue> kwargs);
-
private:
static FunctionSchema defaultSchemaFor(const Function& function) {
std::vector<Argument> args;
@@ -167,7 +148,7 @@ struct TORCH_API Function : public std::enable_shared_from_this<Function> {
std::once_flag executor_init_;

// an optional function that actually creates the method when
-  // emit_call_to(this,...) is first called. this is used by the compiler so
+  // ensure_defined() is called. This is used by the compiler so
// that it can construct methods out of order
std::function<void(Function&)> function_creator_;

54 changes: 0 additions & 54 deletions torch/csrc/jit/script/module.cpp
@@ -31,61 +31,7 @@ void Function::ensure_defined() {
<< " method '" << name() << "' is called recursively. "
<< "Recursive calls are not supported";
}
}

-Value* Function::try_emit_call(
-    Graph& graph,
-    const SourceRange& loc,
-    c10::optional<NamedValue> self,
-    ArrayRef<NamedValue> args,
-    ArrayRef<NamedValue> kwargs,
-    std::ostream* failure_messages,
-    bool conv_tensors_to_nums) {
-  ensure_defined();
-  auto fn = this->graph();
-
-  auto matched_schema = tryMatchSchema(
-      getSchema(),
-      loc,
-      graph,
-      std::move(self),
-      args,
-      kwargs,
-      failure_messages,
-      conv_tensors_to_nums);
-  if (!matched_schema)
-    return nullptr;
-
-  check_single_output();
-  Value* fn_constant = graph.insertNode(graph.create(prim::Constant))
-                           ->output()
-                           ->setType(FunctionType::create(shared_from_this()));
-  matched_schema->inputs.insert(matched_schema->inputs.begin(), fn_constant);
-  Value* result =
-      graph
-          .insertNode(graph.create(prim::CallFunction, matched_schema->inputs))
-          ->output()
-          ->setType(matched_schema->return_types.at(0));
-  return result;
-}
-
-Value* Function::emit_call(
-    Graph& graph,
-    const SourceRange& loc,
-    ArrayRef<NamedValue> args,
-    ArrayRef<NamedValue> kwargs) {
-  std::stringstream failure_messages;
-  if (auto result = try_emit_call(
-          graph,
-          loc,
-          c10::nullopt,
-          args,
-          kwargs,
-          &failure_messages,
-          /*conv_tensors_to_nums=*/true)) {
-    return result;
-  }
-  throw ErrorReport(loc) << failure_messages.str();
-}

void Module::to(at::Device device, at::ScalarType dtype, bool non_blocking) {
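The deleted try_emit_call fused schema matching with IR construction. After this change the two steps are assumed to happen separately at the call site, roughly as below; the caller-side function name is illustrative, and the argument list mirrors the deleted code:

// Hedged sketch of the replacement path.
Value* emitFunctionCall(
    Function& fn,
    Graph& graph,
    const SourceRange& loc,
    ArrayRef<NamedValue> args,
    ArrayRef<NamedValue> kwargs,
    std::ostream* failure_messages) {
  fn.ensure_defined();
  auto matched = tryMatchSchema(
      fn.getSchema(),
      loc,
      graph,
      c10::nullopt,
      args,
      kwargs,
      failure_messages,
      /*conv_tensors_to_nums=*/true);
  if (!matched) {
    return nullptr;
  }
  // IR construction now lives on Graph (see the ir.cpp hunk above).
  return graph.insertFunctionCall(fn.shared_from_this(), *matched);
}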
4 changes: 2 additions & 2 deletions torch/csrc/jit/script/python_sugared_value.cpp
@@ -1,7 +1,7 @@
-#include <torch/csrc/jit/script/python_sugared_value.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/jit/script/module_python.h>
+#include <torch/csrc/jit/script/python_sugared_value.h>
#include <torch/csrc/jit/script/schema_matching.h>
#include <memory>
#include <sstream>
@@ -234,7 +234,7 @@ std::shared_ptr<SugaredValue> OverloadedMethodValue::call(
&err,
true);
if (match) {
-      return MethodValue(module_, fn)
+      return MethodValue(module_, method_name)
.call(loc, caller, inputs, attributes, n_binders);
}
}