Skip to content

Commit 8722952

Browse files
taivufacebook-github-bot
authored andcommitted
Add benchmark for channel_shuffle operator (#43509)
Summary: Pull Request resolved: #43509 Test Plan: Imported from OSS Reviewed By: kimishpatel Differential Revision: D23299972 Pulled By: kimishpatel fbshipit-source-id: 6189d209859da5a41067eb9e8317e3bf7a0fc754
1 parent 6512032 commit 8722952

File tree

2 files changed

+63
-2
lines changed

2 files changed

+63
-2
lines changed

benchmarks/operator_benchmark/benchmark_all_other_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
import operator_benchmark as op_bench
77
from pt import ( # noqa
88
add_test, as_strided_test, batchnorm_test, binary_test, cat_test, # noqa
9-
chunk_test, conv_test, diag_test, embeddingbag_test, fill_test, # noqa
10-
gather_test, linear_test, matmul_test, pool_test, # noqa
9+
channel_shuffle_test, chunk_test, conv_test, diag_test, embeddingbag_test, # noqa
10+
fill_test, gather_test, linear_test, matmul_test, pool_test, # noqa
1111
softmax_test, hardsigmoid_test, hardswish_test, layernorm_test, # noqa
1212
groupnorm_test, instancenorm_test # noqa
1313
)
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from __future__ import absolute_import
2+
from __future__ import division
3+
from __future__ import print_function
4+
from __future__ import unicode_literals
5+
6+
import operator_benchmark as op_bench
7+
import torch
8+
9+
10+
"""Microbenchmarks for channel_shuffle operator."""
11+
12+
13+
# Configs for PT channel_shuffle operator
14+
channel_shuffle_long_configs = op_bench.cross_product_configs(
15+
batch_size=[4, 8],
16+
channels_per_group=[32, 64],
17+
height=[32, 64],
18+
width=[32, 64],
19+
groups=[4, 8],
20+
channel_last=[True, False],
21+
tags=["long"]
22+
)
23+
24+
25+
channel_shuffle_short_configs = op_bench.config_list(
26+
attr_names=["batch_size", "channels_per_group", "height", "width", "groups"],
27+
attrs=[
28+
[2, 16, 16, 16, 2],
29+
[2, 32, 32, 32, 2],
30+
[4, 32, 32, 32, 4],
31+
[4, 64, 64, 64, 4],
32+
[8, 64, 64, 64, 8],
33+
[16, 64, 64, 64, 16],
34+
],
35+
cross_product_configs={
36+
"channel_last": [True, False],
37+
},
38+
tags=["short"]
39+
)
40+
41+
42+
class ChannelSHuffleBenchmark(op_bench.TorchBenchmarkBase):
43+
def init(self, batch_size, channels_per_group, height, width, groups, channel_last):
44+
self.groups = groups
45+
channels = channels_per_group * groups
46+
data_shape = (batch_size, channels, height, width)
47+
self.input_data = torch.rand(data_shape)
48+
if channel_last:
49+
self.input_data = self.input_data.contiguous(memory_format=torch.channels_last)
50+
self.set_module_name('channel_shuffle')
51+
52+
def forward(self):
53+
return torch.channel_shuffle(self.input_data, self.groups)
54+
55+
56+
op_bench.generate_pt_test(channel_shuffle_short_configs + channel_shuffle_long_configs,
57+
ChannelSHuffleBenchmark)
58+
59+
60+
if __name__ == "__main__":
61+
op_bench.benchmark_runner.main()

0 commit comments

Comments
 (0)