Skip to content

Commit d1c4d75

Browse files
Elias Ellisonfacebook-github-bot
authored andcommitted
Add API for unexecuted op (#43629)
Summary: Pull Request resolved: #43629 We have a few places where we count the size a block / subgraph - it's nice to have a shared API to ignore operators that are not executed in the optimized graph (will be used when i add a new profiling node in PR ^^) Test Plan: Imported from OSS Reviewed By: bertmaher Differential Revision: D23358807 Pulled By: eellison fbshipit-source-id: 62c745d9025de94bdafd9f748f7c5a8574cace3f
1 parent 5da97a3 commit d1c4d75

File tree

3 files changed

+10
-3
lines changed

3 files changed

+10
-3
lines changed

torch/csrc/jit/ir/ir.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,11 @@ struct TORCH_API Node {
435435
bool isNondeterministic() const;
436436
bool hasSideEffects() const;
437437

438+
// instructions lowered by the interpreter and not run in the optimized graph
439+
bool notExecutedOp() const {
440+
return kind_ == prim::Constant || kind_ == prim::profile;
441+
}
442+
438443
// Graphs
439444

440445
// Note [Topological invariant]

torch/csrc/jit/passes/create_autodiff_subgraphs.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,7 @@ class SubgraphSlicer {
228228
size_t i = 0;
229229
for (auto it = subgraph->nodes().begin(); it != subgraph->nodes().end();
230230
++it) {
231-
// constants are not interpreted as instructions, ignore them
232-
i += it->kind() != prim::Constant;
231+
i += !it->notExecutedOp();
233232
if (i >= minSubgraphSize_) {
234233
return false;
235234
}

torch/csrc/jit/passes/loop_unrolling.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,13 @@ bool isForLoop(Node* node) {
3535
int64_t limitedBlockSize(Block* body, int64_t limit) {
3636
auto it = body->nodes().begin();
3737
auto end = body->nodes().end();
38-
for (int64_t i = 0; i < limit; ++i, ++it) {
38+
for (int64_t i = 0; i < limit; ++it) {
3939
for (Block* subblock : it->blocks()) {
4040
i += limitedBlockSize(subblock, limit - i);
4141
}
42+
if (!it->notExecutedOp()) {
43+
++i;
44+
}
4245
if (it == end) {
4346
return i;
4447
}

0 commit comments

Comments
 (0)