@@ -9,41 +9,48 @@ namespace prim {
99using namespace ::c10::prim;
1010}
1111
12- void inlineCalls (Block* block) {
13- Node* cur = block->nodes ().front ();
14- Node* end = block->return_node ();
15-
16- while (cur != end) {
17- auto next = cur->next ();
18- for (auto b : cur->blocks ()) {
19- inlineCalls (b);
20- }
21- if (cur->kind () == prim::CallFunction) {
22- AT_ASSERT (cur->inputs ().at (0 )->node ()->kind () == prim::Constant);
23- auto function_constant = cur->inputs ().at (0 )->node ();
24- auto fun_type =
25- function_constant->output ()->type ()->expect <FunctionType>();
26- auto graph = fun_type->function ()->graph ();
27-
28- auto old_output = cur->outputs ();
29- // slice function ptr value
30- auto inputs = cur->inputs ().slice (1 );
31- WithInsertPoint guard (next);
32- auto new_output =
33- inlineCallTo (*cur->owningGraph (), *graph.get (), inputs).at (0 );
34- if (old_output.at (0 )->hasUniqueName ()) {
35- auto name = old_output.at (0 )->uniqueName ();
36- new_output->setUniqueName (name);
37- }
12+ static void replace (
13+ Node* to_replace,
14+ const std::shared_ptr<script::Function>& fn,
15+ at::ArrayRef<Value*> inputs) {
16+ WithInsertPoint guard (to_replace);
17+ auto new_output =
18+ inlineCallTo (*to_replace->owningGraph (), *fn->graph (), inputs).at (0 );
19+ if (to_replace->output ()->hasUniqueName ()) {
20+ new_output->setUniqueName (to_replace->output ()->uniqueName ());
21+ }
22+ to_replace->output ()->replaceAllUsesWith (new_output);
23+ }
3824
39- old_output.at (0 )->replaceAllUsesWith (new_output);
40- next = cur->next ();
41- cur->destroy ();
42- if (!function_constant->hasUses ()) {
43- function_constant->destroy ();
44- }
25+ void inlineCalls (Block* block) {
26+ for (auto it = block->nodes ().begin (), end = block->nodes ().end ();
27+ it != end;) {
28+ Node* cur = *it++;
29+ switch (cur->kind ()) {
30+ case prim::CallFunction: {
31+ AT_ASSERT (cur->inputs ().at (0 )->node ()->kind () == prim::Constant);
32+ auto function_constant = cur->inputs ().at (0 )->node ();
33+ auto fun_type =
34+ function_constant->output ()->type ()->expect <FunctionType>();
35+ replace (cur, fun_type->function (), cur->inputs ().slice (1 ));
36+ cur->destroy ();
37+ if (!function_constant->hasUses ()) {
38+ function_constant->destroy ();
39+ }
40+ } break ;
41+ case prim::CallMethod: {
42+ const std::string& name = cur->s (attr::name);
43+ auto function =
44+ cur->inputs ().at (0 )->type ()->expect <ClassType>()->getMethod (name);
45+ replace (cur, function, cur->inputs ());
46+ cur->destroy ();
47+ } break ;
48+ default : {
49+ for (auto b : cur->blocks ()) {
50+ inlineCalls (b);
51+ }
52+ } break ;
4553 }
46- cur = next;
4754 }
4855}
4956
0 commit comments