Skip to content

Commit bdb889a

Browse files
bertmaherfacebook-github-bot
authored andcommitted
[nnc] Use a descriptive name for fused kernels when profiling (#66990)
Summary: Pull Request resolved: #66990 NNC fusion groups currently show up as "TensorExpr" in the profiler, which is true but not super useful since it obscures what's actually happening in the fusion group. This change will log them as `fused_XXX` where XXX is a (length-limited) series of ops describing the subgraph, for instance `fused_mul_add` to represent a group containing `aten::mul`, `aten::add`. Test Plan: New unit test to check the output of autograd profiler. Reviewed By: dzhulgakov Differential Revision: D31762087 fbshipit-source-id: 3fadbdc67b054faa01aa42e5b6ea2c4a6bc3481f
1 parent 8beabff commit bdb889a

File tree

3 files changed

+17
-1
lines changed

3 files changed

+17
-1
lines changed

test/test_jit_fuser_te.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1962,6 +1962,18 @@ def bn_neither(i, x):
19621962
for fn in [bn, bn_no_weight, bn_no_bias, bn_neither]:
19631963
test(fn, (i, x))
19641964

1965+
def test_profiler(self):
1966+
@torch.jit.script
1967+
def test(x, y, z):
1968+
return x * y + z
1969+
1970+
args = [torch.randn(4) for _ in range(3)]
1971+
with torch.autograd.profiler.profile() as prof:
1972+
for _ in range(3):
1973+
test(*args)
1974+
self.assertIn("fused_mul_add", prof.table())
1975+
1976+
19651977
works_list = [
19661978
'__radd__',
19671979
'__rdiv__',

torch/csrc/jit/passes/tensorexpr_fuser.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1302,7 +1302,7 @@ Operation createTensorExprOp(const Node* node) {
13021302
auto kernel =
13031303
std::make_shared<tensorexpr::TensorExprKernel>(node->g(attr::Subgraph));
13041304
return [kernel](Stack& stack) {
1305-
RECORD_FUNCTION("TensorExpr", std::vector<c10::IValue>());
1305+
RECORD_FUNCTION(kernel->getKernelName(), std::vector<c10::IValue>());
13061306
kernel->run(stack);
13071307
return 0;
13081308
};

torch/csrc/jit/tensorexpr/kernel.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,10 @@ class TORCH_API TensorExprKernel {
122122
return bufferArgs_;
123123
}
124124

125+
const std::string& getKernelName() const {
126+
return codegen_->kernel_func_name();
127+
}
128+
125129
private:
126130
enum BackendType {
127131
kUninitialized,

0 commit comments

Comments
 (0)