We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 62447a5 commit 67ac45eCopy full SHA for 67ac45e
torch/csrc/jit/passes/graph_fuser.cpp
@@ -292,8 +292,6 @@ struct GraphFuser {
292
source,
293
method_name);
294
295
- AT_ASSERT(isFusableNorm(normalization_op));
296
- WithInsertPoint insert_guard{normalization_op};
297
Value* new_output =
298
SubgraphUtils::inlineGraph(nm_graph, inputs, normalization_op).at(0);
299
return new_output;
@@ -327,6 +325,8 @@ struct GraphFuser {
327
325
328
326
return (input - mean) * invstd
329
)SCRIPT";
+ AT_ASSERT(isFusableNorm(normalization_op));
+ WithInsertPoint insert_guard{normalization_op};
330
Value* input = normalization_op->namedInput(attr::input);
331
if (normalization_op->kind() == aten::batch_norm) {
332
Value* input_dim = graph_->insert(aten::dim, {input});
0 commit comments