|
1 | 1 | from contextlib import contextmanager |
2 | 2 | from torch.testing import make_tensor |
3 | | -from typing import Any, List, Tuple |
| 3 | +from typing import Any, List, Tuple, Callable |
4 | 4 | import argparse |
5 | 5 | import random |
6 | 6 | import torch |
7 | 7 | import traceback |
| 8 | +import time |
8 | 9 |
|
9 | 10 | ''' |
10 | 11 | Usage: |
@@ -66,25 +67,42 @@ def load_graph_and_inputs(ir: str) -> Tuple[Any, List[Any]]: |
66 | 67 | torch._C._jit_pass_erase_shape_information(func.graph) |
67 | 68 | return (func, inputs) |
68 | 69 |
|
69 | | - |
70 | | -# TODO add support for timing on CPU |
71 | | -def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float: |
72 | | - graph, _ = load_graph_and_inputs(ir) |
73 | | - for _ in range(warmup_runs): |
74 | | - graph(*inputs) |
75 | | - |
def time_cuda(fn, inputs, test_runs):
    """Return the average wall-clock time of one ``fn(*inputs)`` call on CUDA.

    Timing is done with CUDA events; ``Event.elapsed_time`` reports
    milliseconds, so the returned value is in ms.
    """
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    # Drain any pending GPU work so it is not attributed to this measurement.
    torch.cuda.synchronize()
    start_event.record()
    # NOTE(review): this synchronize happens AFTER the start event is
    # recorded, so its cost is included in the measured window — presumably
    # acceptable for relative comparisons; confirm if absolute numbers matter.
    torch.cuda.synchronize()
    for i in range(test_runs):
        fn(*inputs)
    # Wait for all test iterations to finish before recording the end event.
    torch.cuda.synchronize()
    end_event.record()
    # elapsed_time is only valid once the end event has completed.
    torch.cuda.synchronize()
    return start_event.elapsed_time(end_event) / test_runs
87 | 82 |
|
def time_cpu(fn, inputs, test_runs):
    """Return the average wall-clock time, in milliseconds, of one
    ``fn(*inputs)`` call on CPU, averaged over ``test_runs`` invocations."""
    s = time.perf_counter()
    for _ in range(test_runs):
        fn(*inputs)
    e = time.perf_counter()
    # perf_counter() measures seconds; convert to milliseconds so CPU
    # results use the same unit as time_cuda (cuda.Event.elapsed_time
    # reports ms) and the "ms" labels printed by callers are accurate.
    return (e - s) * 1000 / test_runs
| 89 | + |
| 90 | + |
def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float:
    """Benchmark the TorchScript graph described by ``ir``.

    Loads the graph, performs ``warmup_runs`` untimed warm-up calls, then
    times ``test_runs`` calls using the timer matching the device of the
    first tensor input (CPU or CUDA).

    Returns:
        Average time per call, as reported by ``time_cpu``/``time_cuda``.

    Raises:
        RuntimeError: if no tensor input is present, so the device
            cannot be inferred.
    """
    graph, _ = load_graph_and_inputs(ir)
    for _ in range(warmup_runs):
        graph(*inputs)

    # Infer the device from the first tensor input; non-tensor inputs
    # (ints, floats, ...) carry no device information.
    is_cpu = None
    for inp in inputs:  # `inp` rather than `input`: don't shadow the builtin
        if isinstance(inp, torch.Tensor):
            is_cpu = inp.device.type == "cpu"
            break
    # Raise instead of assert: asserts are stripped under `python -O`, and
    # callers already handle RuntimeError on a per-graph basis.
    if is_cpu is None:
        raise RuntimeError("could not infer device: inputs contain no tensors")

    timer = time_cpu if is_cpu else time_cuda
    return timer(graph, inputs, test_runs)
88 | 106 |
|
89 | 107 | @contextmanager |
90 | 108 | def no_fuser(*args, **kwargs): |
@@ -122,43 +140,59 @@ def run_nvfuser(ir, inputs) -> float: |
122 | 140 | return run_test(ir, inputs) |
123 | 141 |
|
124 | 142 |
|
125 | | -def test_nvfuser(graphs: List[str], baseline_fn, nvfuser_fn): |
def test_runners(graphs: List[str], runners: List[Tuple[str, Callable]]):
    """Benchmark every graph with every (name, fn) runner and print results.

    For each graph, runners execute in order; each successful result is
    reported, and every result after the first is additionally reported as a
    percentage improvement over the previous successful runner. A runner
    that raises RuntimeError is reported as failed and skipped, without
    affecting the comparison chain.
    """
    for i, ir in enumerate(graphs):
        _, inputs = load_graph_and_inputs(ir)
        print(f"Running Graph {ir}")
        prev_result = None
        prev_runner_name = None
        for runner_name, runner_fn in runners:
            try:
                result = runner_fn(ir, inputs)
                # Compare against `is not None`, not truthiness: a previous
                # measurement of 0.0 is a real result, not "absent".
                if prev_result is not None:
                    improvement = (prev_result / result - 1) * 100
                    print(f"{runner_name} : {result:.6f} ms improvement over {prev_runner_name}: improvement: {improvement:.2f}%")
                else:
                    print(f"{runner_name} : {result:.6f} ms")
                prev_result = result
                prev_runner_name = runner_name
            except RuntimeError:
                print(f" Graph {i} failed for {runner_name} :", traceback.format_exc())
135 | 162 |
|
136 | 163 |
|
137 | 164 | def run(): |
138 | 165 | parser = argparse.ArgumentParser( |
139 | 166 | description="Extracts torchscript IR from log files and, optionally, benchmarks it or outputs the IR" |
140 | 167 | ) |
141 | 168 | parser.add_argument("filename", help="Filename of log file") |
142 | | - parser.add_argument("--nvfuser", dest="nvfuser", action="store_true", help="benchmark nvfuser against no fusion") |
143 | | - parser.add_argument("--no-nvfuser", dest="nvfuser", action="store_false", help="DON'T benchmark nvfuser against no fusion") |
| 169 | + parser.add_argument("--nvfuser", dest="nvfuser", action="store_true", help="benchmark nvfuser") |
| 170 | + parser.add_argument("--no-nvfuser", dest="nvfuser", action="store_false", help="DON'T benchmark nvfuser") |
144 | 171 | parser.set_defaults(nvfuser=False) |
145 | | - parser.add_argument("--nvfuser-nnc", dest="nvfuser_nnc", action="store_true", help="benchmark nvfuser against nnc") |
146 | | - parser.add_argument("--no-nvfuser-nnc", dest="nvfuser_nnc", action="store_false", help="DON'T benchmark nvfuser against nnc") |
147 | | - parser.set_defaults(nvfuser_nnc=False) |
| 172 | + parser.add_argument("--nnc", dest="nnc", action="store_true", help="benchmark nnc") |
| 173 | + parser.add_argument("--no-nnc", dest="nnc", action="store_false", help="DON'T benchmark nnc") |
| 174 | + parser.set_defaults(nnc=False) |
| 175 | + |
| 176 | + parser.add_argument("--baseline", dest="baseline", action="store_true", help="benchmark baseline") |
| 177 | + parser.add_argument("--no-baseline", dest="baseline", action="store_false", help="DON'T benchmark baseline") |
| 178 | + parser.set_defaults(baseline=False) |
| 179 | + |
148 | 180 | parser.add_argument("--output", dest="output", action="store_true", help="Output graph IR") |
149 | 181 | parser.add_argument("--no-output", dest="output", action="store_false", help="DON'T output graph IR") |
150 | 182 | parser.set_defaults(output=False) |
151 | 183 |
|
152 | 184 | args = parser.parse_args() |
153 | 185 | graphs = extract_ir(args.filename) |
154 | 186 |
|
| 187 | + options = [] |
| 188 | + if args.baseline: |
| 189 | + options.append(("Baseline no fusion", run_baseline_no_fusion)) |
| 190 | + if args.nnc: |
| 191 | + options.append(("NNC", run_nnc)) |
155 | 192 | if args.nvfuser: |
156 | | - print("NVFuser vs no fusion:") |
157 | | - test_nvfuser(graphs, run_baseline_no_fusion, run_nvfuser) |
| 193 | + options.append(("NVFuser", run_nvfuser)) |
158 | 194 |
|
159 | | - if args.nvfuser_nnc: |
160 | | - print("NVFuser vs NNC:") |
161 | | - test_nvfuser(graphs, run_nnc, run_nvfuser) |
| 195 | + test_runners(graphs, options) |
162 | 196 |
|
163 | 197 | if args.output: |
164 | 198 | quoted = [] |
|
0 commit comments