Skip to content

Commit c90be03

Browse files
Elias Ellison and pytorchmergebot
authored and committed
Extend Graph Export to NNC, extend script to support CPU (#74076)
Summary: Pull Request resolved: #74076 Extends the repro script to cpu and NNC. As in file: Usage: ``` 1. Run your script and pipe into a log file PYTORCH_JIT_LOG_LEVEL=">>tensorexpr_fuser" python3 my_test.py &> log.txt 2. Run log_extract: log_extract.py log.txt --baseline --nnc ``` Test Plan: Imported from OSS Reviewed By: gchanan Differential Revision: D34946883 Pulled By: eellison fbshipit-source-id: 644012dbbca0b490820ef83e761c06b0dd009e52 (cherry picked from commit 5256c8f)
1 parent 9c4a637 commit c90be03

File tree

2 files changed

+65
-27
lines changed

2 files changed

+65
-27
lines changed

scripts/jit/log_extract.py

Lines changed: 61 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from contextlib import contextmanager
22
from torch.testing import make_tensor
3-
from typing import Any, List, Tuple
3+
from typing import Any, List, Tuple, Callable
44
import argparse
55
import random
66
import torch
77
import traceback
8+
import time
89

910
'''
1011
Usage:
@@ -66,25 +67,42 @@ def load_graph_and_inputs(ir: str) -> Tuple[Any, List[Any]]:
6667
torch._C._jit_pass_erase_shape_information(func.graph)
6768
return (func, inputs)
6869

69-
70-
# TODO add support for timing on CPU
71-
def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float:
72-
graph, _ = load_graph_and_inputs(ir)
73-
for _ in range(warmup_runs):
74-
graph(*inputs)
75-
70+
def time_cuda(fn, inputs, test_runs):
    """Time ``fn(*inputs)`` on CUDA, averaged over ``test_runs`` runs.

    Uses CUDA events for measurement; ``elapsed_time`` reports milliseconds,
    so the returned value is the mean milliseconds per run. Synchronizes
    around the timed region so queued kernels are fully accounted for.
    """
    begin = torch.cuda.Event(enable_timing=True)
    finish = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    begin.record()
    torch.cuda.synchronize()
    for _ in range(test_runs):
        fn(*inputs)
    torch.cuda.synchronize()
    finish.record()
    torch.cuda.synchronize()
    return begin.elapsed_time(finish) / test_runs
8782

83+
def time_cpu(fn, inputs, test_runs):
    """Time ``fn(*inputs)`` on CPU, averaged over ``test_runs`` runs.

    Returns the mean time per run in **milliseconds**, matching the unit of
    ``time_cuda`` (CUDA ``Event.elapsed_time`` reports ms) and the " ms"
    label that callers attach when printing these results. The original
    returned raw seconds here, which made CPU numbers mislabeled by 1000x.
    """
    start = time.perf_counter()
    for _ in range(test_runs):
        fn(*inputs)
    elapsed = time.perf_counter() - start
    # perf_counter measures seconds; convert to ms for unit consistency.
    return elapsed * 1000 / test_runs
89+
90+
91+
def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float:
    """Benchmark a TorchScript graph given as an IR string.

    Args:
        ir: TorchScript IR text, loaded via ``load_graph_and_inputs``.
        inputs: sample inputs for the graph; at least one must be a
            ``torch.Tensor``, whose device selects CPU vs CUDA timing.
        warmup_runs: untimed runs executed first (JIT warm-up/fusion).
        test_runs: timed runs to average over.

    Returns:
        Average time per run, as reported by ``time_cpu``/``time_cuda``.
    """
    graph, _ = load_graph_and_inputs(ir)
    for _ in range(warmup_runs):
        graph(*inputs)

    # Pick the timing strategy from the device of the first tensor input.
    # (`inp`, not `input`, to avoid shadowing the builtin.)
    is_cpu = None
    for inp in inputs:
        if isinstance(inp, torch.Tensor):
            is_cpu = inp.device.type == "cpu"
            break
    assert is_cpu is not None, "expected at least one tensor input"

    timer = time_cpu if is_cpu else time_cuda
    return timer(graph, inputs, test_runs)
88106

89107
@contextmanager
90108
def no_fuser(*args, **kwargs):
@@ -122,43 +140,59 @@ def run_nvfuser(ir, inputs) -> float:
122140
return run_test(ir, inputs)
123141

124142

125-
def test_nvfuser(graphs: List[str], baseline_fn, nvfuser_fn):
143+
def test_runners(graphs: List[str], runners: List[Tuple[str, Callable]]):
    """Benchmark every ``(name, fn)`` runner over every extracted graph.

    Each runner after the first is additionally reported as a percentage
    improvement over the immediately preceding runner's result. A runner
    that raises ``RuntimeError`` is reported as failed and the sweep
    continues with the remaining runners/graphs.
    """
    for i, ir in enumerate(graphs):
        _, inputs = load_graph_and_inputs(ir)
        print(f"Running Graph {ir}")
        prev_result = None
        prev_runner_name = None
        for runner_name, runner_fn in runners:
            try:
                result = runner_fn(ir, inputs)
                # `is not None`, not truthiness: a legitimate 0.0 timing must
                # still serve as the comparison baseline for the next runner.
                if prev_result is not None:
                    improvement = (prev_result / result - 1) * 100
                    print(f"{runner_name} : {result:.6f} ms improvement over {prev_runner_name}: improvement: {improvement:.2f}%")
                else:
                    print(f"{runner_name} : {result:.6f} ms")
                prev_result = result
                prev_runner_name = runner_name
            except RuntimeError:
                print(f" Graph {i} failed for {runner_name} :", traceback.format_exc())
135162

136163

137164
def run():
138165
parser = argparse.ArgumentParser(
139166
description="Extracts torchscript IR from log files and, optionally, benchmarks it or outputs the IR"
140167
)
141168
parser.add_argument("filename", help="Filename of log file")
142-
parser.add_argument("--nvfuser", dest="nvfuser", action="store_true", help="benchmark nvfuser against no fusion")
143-
parser.add_argument("--no-nvfuser", dest="nvfuser", action="store_false", help="DON'T benchmark nvfuser against no fusion")
169+
parser.add_argument("--nvfuser", dest="nvfuser", action="store_true", help="benchmark nvfuser")
170+
parser.add_argument("--no-nvfuser", dest="nvfuser", action="store_false", help="DON'T benchmark nvfuser")
144171
parser.set_defaults(nvfuser=False)
145-
parser.add_argument("--nvfuser-nnc", dest="nvfuser_nnc", action="store_true", help="benchmark nvfuser against nnc")
146-
parser.add_argument("--no-nvfuser-nnc", dest="nvfuser_nnc", action="store_false", help="DON'T benchmark nvfuser against nnc")
147-
parser.set_defaults(nvfuser_nnc=False)
172+
parser.add_argument("--nnc", dest="nnc", action="store_true", help="benchmark nnc")
173+
parser.add_argument("--no-nnc", dest="nnc", action="store_false", help="DON'T benchmark nnc")
174+
parser.set_defaults(nnc=False)
175+
176+
parser.add_argument("--baseline", dest="baseline", action="store_true", help="benchmark baseline")
177+
parser.add_argument("--no-baseline", dest="baseline", action="store_false", help="DON'T benchmark baseline")
178+
parser.set_defaults(baseline=False)
179+
148180
parser.add_argument("--output", dest="output", action="store_true", help="Output graph IR")
149181
parser.add_argument("--no-output", dest="output", action="store_false", help="DON'T output graph IR")
150182
parser.set_defaults(output=False)
151183

152184
args = parser.parse_args()
153185
graphs = extract_ir(args.filename)
154186

187+
options = []
188+
if args.baseline:
189+
options.append(("Baseline no fusion", run_baseline_no_fusion))
190+
if args.nnc:
191+
options.append(("NNC", run_nnc))
155192
if args.nvfuser:
156-
print("NVFuser vs no fusion:")
157-
test_nvfuser(graphs, run_baseline_no_fusion, run_nvfuser)
193+
options.append(("NVFuser", run_nvfuser))
158194

159-
if args.nvfuser_nnc:
160-
print("NVFuser vs NNC:")
161-
test_nvfuser(graphs, run_nnc, run_nvfuser)
195+
test_runners(graphs, options)
162196

163197
if args.output:
164198
quoted = []

torch/csrc/jit/passes/tensorexpr_fuser.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,10 @@ class TensorExprFuser {
745745
}
746746
// Cleanup the subgraph from duplicated constants while we're at it.
747747
ConstantPooling(subgraph);
748+
749+
if (GRAPH_DEBUG_ENABLED) {
750+
GRAPH_EXPORT("", subgraph);
751+
}
748752
return false;
749753
}
750754

0 commit comments

Comments (0)