Skip to content

Commit 67ac45e

Browse files
committed
fix the insert_guard for norm decomposation
1 parent 62447a5 commit 67ac45e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

torch/csrc/jit/passes/graph_fuser.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,6 @@ struct GraphFuser {
292292
source,
293293
method_name);
294294

295-
AT_ASSERT(isFusableNorm(normalization_op));
296-
WithInsertPoint insert_guard{normalization_op};
297295
Value* new_output =
298296
SubgraphUtils::inlineGraph(nm_graph, inputs, normalization_op).at(0);
299297
return new_output;
@@ -327,6 +325,8 @@ struct GraphFuser {
327325
328326
return (input - mean) * invstd
329327
)SCRIPT";
328+
AT_ASSERT(isFusableNorm(normalization_op));
329+
WithInsertPoint insert_guard{normalization_op};
330330
Value* input = normalization_op->namedInput(attr::input);
331331
if (normalization_op->kind() == aten::batch_norm) {
332332
Value* input_dim = graph_->insert(aten::dim, {input});

0 commit comments

Comments
 (0)