-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathkernel_utils.py
More file actions
50 lines (44 loc) · 1.63 KB
/
kernel_utils.py
File metadata and controls
50 lines (44 loc) · 1.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
from scipy.stats import gmean
def benchmark(fn, warmup=100, rep=200, quantiles=None, fast_flush=True):
    """Time a CUDA callable with CUDA events, flushing L2 between runs.

    Adapted from:
    https://github.com/nox-410/tvm.tl/blob/tl/python/tvm/tl/utils.py#L144

    Parameters
    ----------
    fn : callable
        Zero-argument callable that launches the CUDA work to be timed.
    warmup : float, default 100
        Target total warm-up time in milliseconds (converted to an
        iteration count from an initial latency estimate).
    rep : float, default 200
        Target total measurement time in milliseconds.
    quantiles : sequence of float or None, default None
        If given, return these quantiles of the per-run times (ms);
        a single requested quantile is returned as a scalar float.
        If None, return the geometric mean of the per-run times (ms).
    fast_flush : bool, default True
        Use a 4-byte-element 256 MB flush buffer (faster to zero) instead
        of a 1-byte-element one.

    Returns
    -------
    float or list[float]
        Per-run time statistic(s) in milliseconds (see ``quantiles``).
    """
    # One untimed call so lazy compilation / autotuning doesn't pollute timing.
    fn()
    torch.cuda.synchronize()
    # 256 MB buffer; zeroing it before each run evicts fn's data from L2
    # so every measured run starts from a cold cache.
    if fast_flush:
        cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda")
    else:
        cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda")
    # Rough per-call latency estimate, used to size the warm-up and
    # measurement loops so they hit the requested total durations.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        cache.zero_()
        fn()
    end_event.record()
    torch.cuda.synchronize()
    # Clamp: a trivially fast fn can yield 0.0 ms and would otherwise
    # raise ZeroDivisionError below.
    estimate_ms = max(start_event.elapsed_time(end_event) / 5, 1e-6)
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
    # Warm-up
    for _ in range(n_warmup):
        fn()
    # Benchmark: flush L2 before each timed run.
    for i in range(n_repeat):
        cache.zero_()
        start_events[i].record()
        fn()
        end_events[i].record()
    torch.cuda.synchronize()
    # elapsed_time is a host-side CUDA call; gather once and reuse the
    # tensor (the original recomputed the full list a second time).
    times = torch.tensor(
        [s.elapsed_time(e) for s, e in zip(start_events, end_events)],
        dtype=torch.float,
    )
    if quantiles is not None:
        ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist()
        if len(ret) == 1:
            ret = ret[0]
        return ret
    return gmean(times.tolist())