Skip to content

Commit 6e657c5

Browse files
zdevito authored and facebook-github-bot committed
Add CallMethod, inline eagerly (#21116)
Summary: Pull Request resolved: #21116 ghimport-source-id: 3c47e33 Reviewed By: eellison Differential Revision: D15552816 Pulled By: zdevito fbshipit-source-id: 708fe87439d94117dca0a26c98f0917f497f718f
1 parent 0f58d20 commit 6e657c5

File tree

13 files changed

+166
-159
lines changed

13 files changed

+166
-159
lines changed

test/test_jit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16250,7 +16250,7 @@ def _xor(): # noqa: E306
1625016250
_or, _xor]:
1625116251
self.checkScript(func, ())
1625216252

16253-
with self.assertRaisesRegex(RuntimeError, "because it does not define a __add__"):
16253+
with self.assertRaisesRegex(RuntimeError, "__add__ method"):
1625416254
@torch.jit.script
1625516255
def test():
1625616256
return Foo(torch.tensor(1)) + Foo(torch.tensor(1))

torch/csrc/jit/ir.cpp

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,11 @@ void Node::printAttributes(std::ostream& out, bool ignore_subgraph = false)
201201
}
202202

203203
SourceRange Node::sourceRange() const {
204-
if(source_range_) {
205-
return *source_range_;
206-
}
207-
std::stringstream ss;
208-
return SourceRange(ss.str());
204+
if (source_range_) {
205+
return *source_range_;
206+
}
207+
std::stringstream ss;
208+
return SourceRange(ss.str());
209209
}
210210

211211
static std::ostream& indent(std::ostream& out, size_t level) {
@@ -236,7 +236,6 @@ std::ostream& Node::print(
236236

237237
groups->push_back(this);
238238
} else {
239-
240239
out << kind().toQualString();
241240
if (hasAttributes()) {
242241
printAttributes(out);
@@ -1429,6 +1428,30 @@ Node* Graph::createGetAttr(Value* obj, const std::string& field) {
14291428
return n;
14301429
}
14311430

1431+
Value* Graph::insertFunctionCall(
1432+
std::shared_ptr<script::Function> callee,
1433+
script::MatchedSchema& matched) {
1434+
Value* fn_constant = insertNode(create(prim::Constant))
1435+
->output()
1436+
->setType(FunctionType::create(std::move(callee)));
1437+
std::vector<Value*> inputs = {fn_constant};
1438+
inputs.insert(inputs.end(), matched.inputs.begin(), matched.inputs.end());
1439+
Value* result = insertNode(create(prim::CallFunction, inputs))
1440+
->output()
1441+
->setType(matched.return_types.at(0));
1442+
return result;
1443+
}
1444+
1445+
Value* Graph::insertMethodCall(
1446+
std::string method_name,
1447+
script::MatchedSchema& matched) {
1448+
Value* result = insertNode(create(prim::CallMethod, matched.inputs))
1449+
->s_(attr::name, std::move(method_name))
1450+
->output()
1451+
->setType(matched.return_types.at(0));
1452+
return result;
1453+
}
1454+
14321455
Node* Graph::createClone(
14331456
Node* n,
14341457
const std::function<Value*(Value*)>& value_map,

torch/csrc/jit/ir.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,11 @@ namespace aten {
7979
using namespace ::c10::aten;
8080
}
8181

82+
namespace script {
83+
struct Function;
84+
struct MatchedSchema;
85+
} // namespace script
86+
8287
// Graph represents one "function" of computation.
8388
// It uses a simple ownership model where the graph owns all the nodes inside
8489
// it. All references inside the graph are raw pointers. Destroying the Graph
@@ -1093,6 +1098,13 @@ struct Graph {
10931098
return insertNode(createGetAttr(obj, field))->output();
10941099
}
10951100

1101+
TORCH_API Value* insertFunctionCall(
1102+
std::shared_ptr<script::Function> callee,
1103+
script::MatchedSchema& matched);
1104+
TORCH_API Value* insertMethodCall(
1105+
std::string method_name,
1106+
script::MatchedSchema& matched);
1107+
10961108
// Note: defined in python_ir.cpp and can be used only in python extension
10971109
Node* createPythonOp(
10981110
THPObjectPtr&& pyobj,

torch/csrc/jit/passes/inliner.cpp

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -9,41 +9,48 @@ namespace prim {
99
using namespace ::c10::prim;
1010
}
1111

12-
void inlineCalls(Block* block) {
13-
Node* cur = block->nodes().front();
14-
Node* end = block->return_node();
15-
16-
while (cur != end) {
17-
auto next = cur->next();
18-
for (auto b : cur->blocks()) {
19-
inlineCalls(b);
20-
}
21-
if (cur->kind() == prim::CallFunction) {
22-
AT_ASSERT(cur->inputs().at(0)->node()->kind() == prim::Constant);
23-
auto function_constant = cur->inputs().at(0)->node();
24-
auto fun_type =
25-
function_constant->output()->type()->expect<FunctionType>();
26-
auto graph = fun_type->function()->graph();
27-
28-
auto old_output = cur->outputs();
29-
// slice function ptr value
30-
auto inputs = cur->inputs().slice(1);
31-
WithInsertPoint guard(next);
32-
auto new_output =
33-
inlineCallTo(*cur->owningGraph(), *graph.get(), inputs).at(0);
34-
if (old_output.at(0)->hasUniqueName()) {
35-
auto name = old_output.at(0)->uniqueName();
36-
new_output->setUniqueName(name);
37-
}
12+
static void replace(
13+
Node* to_replace,
14+
const std::shared_ptr<script::Function>& fn,
15+
at::ArrayRef<Value*> inputs) {
16+
WithInsertPoint guard(to_replace);
17+
auto new_output =
18+
inlineCallTo(*to_replace->owningGraph(), *fn->graph(), inputs).at(0);
19+
if (to_replace->output()->hasUniqueName()) {
20+
new_output->setUniqueName(to_replace->output()->uniqueName());
21+
}
22+
to_replace->output()->replaceAllUsesWith(new_output);
23+
}
3824

39-
old_output.at(0)->replaceAllUsesWith(new_output);
40-
next = cur->next();
41-
cur->destroy();
42-
if (!function_constant->hasUses()) {
43-
function_constant->destroy();
44-
}
25+
void inlineCalls(Block* block) {
26+
for (auto it = block->nodes().begin(), end = block->nodes().end();
27+
it != end;) {
28+
Node* cur = *it++;
29+
switch (cur->kind()) {
30+
case prim::CallFunction: {
31+
AT_ASSERT(cur->inputs().at(0)->node()->kind() == prim::Constant);
32+
auto function_constant = cur->inputs().at(0)->node();
33+
auto fun_type =
34+
function_constant->output()->type()->expect<FunctionType>();
35+
replace(cur, fun_type->function(), cur->inputs().slice(1));
36+
cur->destroy();
37+
if (!function_constant->hasUses()) {
38+
function_constant->destroy();
39+
}
40+
} break;
41+
case prim::CallMethod: {
42+
const std::string& name = cur->s(attr::name);
43+
auto function =
44+
cur->inputs().at(0)->type()->expect<ClassType>()->getMethod(name);
45+
replace(cur, function, cur->inputs());
46+
cur->destroy();
47+
} break;
48+
default: {
49+
for (auto b : cur->blocks()) {
50+
inlineCalls(b);
51+
}
52+
} break;
4553
}
46-
cur = next;
4754
}
4855
}
4956

torch/csrc/jit/script/builtin_functions.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
#include <torch/csrc/jit/script/builtin_functions.h>
21
#include <torch/csrc/api/include/torch/jit.h>
32
#include <torch/csrc/jit/code_template.h>
3+
#include <torch/csrc/jit/script/builtin_functions.h>
44
#include <torch/csrc/jit/script/resolver.h>
55

66
namespace torch {
@@ -38,8 +38,9 @@ def _${name}(x: BroadcastingList${Length}[${Scalar}]) -> List[${Scalar}]:
3838
)SCRIPT");
3939

4040
struct BuiltinFunctionRegistry {
41-
const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
42-
const static std::vector<Function*> empty;
41+
const std::vector<std::shared_ptr<Function>>& getAllBuiltinFunctionsFor(
42+
Symbol name) {
43+
const static std::vector<std::shared_ptr<Function>> empty;
4344
// when initializing the builtin function library, we will re-enter
4445
// getAllBuiltinFunctionsFor since it is called in the compiler to
4546
// lookup builtins and initializing the builtin functions calls the
@@ -68,7 +69,7 @@ struct BuiltinFunctionRegistry {
6869
cu->define(source, script::nativeResolver(), /*self=*/nullptr);
6970
for (auto& method : cu->get_functions()) {
7071
builtins_by_name[Symbol::fromQualString("aten::" + method->name())]
71-
.push_back(method.get());
72+
.push_back(method);
7273
}
7374
}
7475
void loadBuiltinFunctions() {
@@ -98,10 +99,12 @@ struct BuiltinFunctionRegistry {
9899
enum { UNINITIALIZED, INTIIALIZING, INITIALIZED } state = UNINITIALIZED;
99100
std::recursive_mutex mutex;
100101
std::vector<std::shared_ptr<CompilationUnit>> modules;
101-
std::unordered_map<Symbol, std::vector<Function*>> builtins_by_name;
102+
std::unordered_map<Symbol, std::vector<std::shared_ptr<Function>>>
103+
builtins_by_name;
102104
};
103105

104-
const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name) {
106+
const std::vector<std::shared_ptr<Function>>& getAllBuiltinFunctionsFor(
107+
Symbol name) {
105108
static BuiltinFunctionRegistry registry;
106109
return registry.getAllBuiltinFunctionsFor(name);
107110
}

torch/csrc/jit/script/builtin_functions.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ namespace torch {
77
namespace jit {
88
namespace script {
99

10-
TORCH_API const std::vector<Function*>& getAllBuiltinFunctionsFor(Symbol name);
11-
10+
TORCH_API const std::vector<std::shared_ptr<Function>>&
11+
getAllBuiltinFunctionsFor(Symbol name);
1212
}
1313
} // namespace jit
1414
} // namespace torch

torch/csrc/jit/script/compilation_unit.h

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -121,25 +121,6 @@ struct TORCH_API Function : public std::enable_shared_from_this<Function> {
121121
return executor_;
122122
}
123123

124-
// returns nullptr and fills in failure_messages if the callee does not
125-
// match the functions schema
126-
127-
// TODO: defined in module.cpp, move to compilation_unit.cpp
128-
Value* try_emit_call(
129-
Graph& graph,
130-
const SourceRange& loc,
131-
c10::optional<NamedValue> self,
132-
ArrayRef<NamedValue> args,
133-
ArrayRef<NamedValue> kwargs,
134-
std::ostream* failure_messages,
135-
bool conv_tensors_to_nums);
136-
137-
Value* emit_call(
138-
Graph& graph,
139-
const SourceRange& loc,
140-
ArrayRef<NamedValue> args,
141-
ArrayRef<NamedValue> kwargs);
142-
143124
private:
144125
static FunctionSchema defaultSchemaFor(const Function& function) {
145126
std::vector<Argument> args;
@@ -167,7 +148,7 @@ struct TORCH_API Function : public std::enable_shared_from_this<Function> {
167148
std::once_flag executor_init_;
168149

169150
// an optional function that actually creates the method when
170-
// emit_call_to(this,...) is first called. this is used by the compiler so
151+
// ensure_defined() is called. This is used by the compiler so
171152
// that it can construct methods out of order
172153
std::function<void(Function&)> function_creator_;
173154

torch/csrc/jit/script/module.cpp

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -31,61 +31,7 @@ void Function::ensure_defined() {
3131
<< " method '" << name() << "' is called recursively. "
3232
<< "Recursive calls are not supported";
3333
}
34-
}
35-
36-
Value* Function::try_emit_call(
37-
Graph& graph,
38-
const SourceRange& loc,
39-
c10::optional<NamedValue> self,
40-
ArrayRef<NamedValue> args,
41-
ArrayRef<NamedValue> kwargs,
42-
std::ostream* failure_messages,
43-
bool conv_tensors_to_nums) {
44-
ensure_defined();
45-
auto fn = this->graph();
46-
47-
auto matched_schema = tryMatchSchema(
48-
getSchema(),
49-
loc,
50-
graph,
51-
std::move(self),
52-
args,
53-
kwargs,
54-
failure_messages,
55-
conv_tensors_to_nums);
56-
if (!matched_schema)
57-
return nullptr;
58-
5934
check_single_output();
60-
Value* fn_constant = graph.insertNode(graph.create(prim::Constant))
61-
->output()
62-
->setType(FunctionType::create(shared_from_this()));
63-
matched_schema->inputs.insert(matched_schema->inputs.begin(), fn_constant);
64-
Value* result =
65-
graph
66-
.insertNode(graph.create(prim::CallFunction, matched_schema->inputs))
67-
->output()
68-
->setType(matched_schema->return_types.at(0));
69-
return result;
70-
}
71-
72-
Value* Function::emit_call(
73-
Graph& graph,
74-
const SourceRange& loc,
75-
ArrayRef<NamedValue> args,
76-
ArrayRef<NamedValue> kwargs) {
77-
std::stringstream failure_messages;
78-
if (auto result = try_emit_call(
79-
graph,
80-
loc,
81-
c10::nullopt,
82-
args,
83-
kwargs,
84-
&failure_messages,
85-
/*conv_tensors_to_nums=*/true)) {
86-
return result;
87-
}
88-
throw ErrorReport(loc) << failure_messages.str();
8935
}
9036

9137
void Module::to(at::Device device, at::ScalarType dtype, bool non_blocking) {

torch/csrc/jit/script/python_sugared_value.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
#include <torch/csrc/jit/script/python_sugared_value.h>
21
#include <torch/csrc/Dtype.h>
32
#include <torch/csrc/Layout.h>
43
#include <torch/csrc/jit/script/module_python.h>
4+
#include <torch/csrc/jit/script/python_sugared_value.h>
55
#include <torch/csrc/jit/script/schema_matching.h>
66
#include <memory>
77
#include <sstream>
@@ -234,7 +234,7 @@ std::shared_ptr<SugaredValue> OverloadedMethodValue::call(
234234
&err,
235235
true);
236236
if (match) {
237-
return MethodValue(module_, fn)
237+
return MethodValue(module_, method_name)
238238
.call(loc, caller, inputs, attributes, n_binders);
239239
}
240240
}

0 commit comments

Comments (0)