Commit e84bf88
[ATen][CUDA] Implement 128 bit vectorization v2 (#145746)
This PR is a rebase of my previous one, #141959.
Description from the original PR:
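This PR implements 128-bit vectorization. It improves the performance of contiguous elementwise ops by 4-10% on Hopper H100 for 16-bit dtypes (float16/bfloat16); float32 timings are largely unchanged (see the results below).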
<details>
<summary>The benchmark code used </summary>
```Python
import time
import torch
from torch.profiler import profile, ProfilerActivity
def benchmark(function, dtype=torch.float32, check_numerics=True, print_profile=False):
    """Time ``function`` on 1-D CUDA tensors of 2**24 through 2**29 elements.

    For each size the op is first called several times untimed, optionally
    profiled once, and then timed over a fixed number of back-to-back calls
    bracketed by ``torch.cuda.synchronize()``. One line per size is printed as
    ``<op name>, <dtype>, <numel>, <mean seconds per call>``.

    Args:
        function: Callable taking one tensor and returning a tensor; its
            ``__name__`` appears in the printed line.
        dtype: dtype of the random input tensors.
        check_numerics: If true, rerun ``function`` on a CPU copy of the timed
            input and compare it with the CUDA result. A mismatch is printed,
            not raised.
        print_profile: If true, print a ``torch.profiler`` table (CUDA
            activity only) for one extra call per size.
    """
    repeats = 6  # shared by the untimed warm-up loop and the timed loop
    device = torch.device("cuda")

    for shape in [1 << p for p in range(24, 30)]:
        # Untimed calls before measuring (warm-up).
        for _ in range(repeats):
            x = torch.randn(shape, device=device, dtype=dtype)
            function(x)

        if print_profile:
            x = torch.randn(shape, device=device, dtype=dtype)
            with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
                function(x)
            print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

        x = torch.randn(shape, device=device, dtype=dtype)
        # Synchronize on both sides of the timed region so the wall-clock
        # interval covers the completed GPU work.
        torch.cuda.synchronize()
        t1 = time.perf_counter()
        for _ in range(repeats):
            y = function(x)
        torch.cuda.synchronize()
        t2 = time.perf_counter()
        perf_time = (t2 - t1) / repeats

        print(f"{function.__name__}, {dtype}, {shape}, {perf_time}")

        if check_numerics:
            y_cpu = function(x.cpu()).cuda()
            try:
                # assert_close supersedes the deprecated
                # torch.testing.assert_allclose. Tolerances are assert_close's
                # per-dtype defaults; equal_nan=True is meant to keep the old
                # helper's NaN handling.
                torch.testing.assert_close(y_cpu, y, equal_nan=True)
            except AssertionError as error:
                print("An exception occurred:", error)
def main():
    """Run ``benchmark`` for every (op, dtype) pair in the lists below."""
    elementwise_ops = [
        torch.relu,
        torch.sigmoid,
        torch.tanh,
        torch.nn.functional.gelu,
        torch.sin,
        torch.exp,
    ]
    float_dtypes = [torch.float16, torch.bfloat16, torch.float32]

    for fn in elementwise_ops:
        for dt in float_dtypes:
            benchmark(fn, dtype=dt)
            # NOTE(review): releasing the CUDA cache after every (op, dtype)
            # run is inferred from the flattened listing; confirm the nesting
            # level against the original script.
            torch.cuda.empty_cache()
# Run the benchmark sweep only when executed as a script, not on import.
if __name__ == "__main__":
    main()
```
</details>
<details>
<summary> Results </summary>
| op | dtype | size | time after | time before | % improvement |
| ---- | ---- | ---- | ---- | ---- | ---- |
| relu | torch.float16 | 33554432 | 4.84E-05 | 5.06E-05 | 4.66296539127052 |
| relu | torch.float16 | 67108864 | 9.22E-05 | 9.64E-05 | 4.56491432752297 |
| relu | torch.float16 | 134217728 | 0.000180343495837102 | 0.000187981834945579 | 4.23543919508829 |
| relu | torch.float16 | 268435456 | 0.000355071155354381 | 0.000370856161074092 | 4.44558942107169 |
| relu | torch.float16 | 536870912 | 0.000704489842367669 | 0.000736006341564159 | 4.47366268483987 |
| relu | torch.bfloat16 | 16777216 | 3.03E-05 | 3.04E-05 | 0.166504085842689 |
| relu | torch.bfloat16 | 33554432 | 4.89E-05 | 5.06E-05 | 3.45848238875716 |
| relu | torch.bfloat16 | 67108864 | 9.32E-05 | 9.65E-05 | 3.56122651631445 |
| relu | torch.bfloat16 | 134217728 | 0.000180805509444326 | 0.000187998676362137 | 3.97840029317567 |
| relu | torch.bfloat16 | 268435456 | 0.000356242332297067 | 0.000371279485989362 | 4.22104627356745 |
| relu | torch.bfloat16 | 536870912 | 0.000708114336399982 | 0.000736773828975856 | 4.04729732229083 |
| relu | torch.float32 | 16777216 | 5.61E-05 | 5.61E-05 | 0.0442587268354941 |
| relu | torch.float32 | 33554432 | 9.33E-05 | 9.30E-05 | -0.259070913799022 |
| relu | torch.float32 | 67108864 | 0.000181321326332788 | 0.000181289506144822 | -0.0175490597877115 |
| relu | torch.float32 | 134217728 | 0.000356896334172537 | 0.000356570177245885 | -0.0913870206618981 |
| relu | torch.float32 | 268435456 | 0.000709421835684528 | 0.000707465515006334 | -0.275762681635911 |
| relu | torch.float32 | 536870912 | 0.00141372415237129 | 0.00141036518228551 | -0.237597276678471 |
| sigmoid | torch.float16 | 16777216 | 3.10E-05 | 3.16E-05 | 2.10012593866895 |
| sigmoid | torch.float16 | 33554432 | 4.91E-05 | 5.23E-05 | 6.37710600666122 |
| sigmoid | torch.float16 | 67108864 | 9.30E-05 | 0.000100057009452333 | 7.61866144555331 |
| sigmoid | torch.float16 | 134217728 | 0.000180928347011407 | 0.000194982004662355 | 7.76752669390248 |
| sigmoid | torch.float16 | 268435456 | 0.000355658994521946 | 0.00038468533117945 | 8.16128288742412 |
| sigmoid | torch.float16 | 536870912 | 0.000705982849467546 | 0.000764021339515845 | 8.22094900634937 |
| sigmoid | torch.bfloat16 | 16777216 | 3.08E-05 | 3.17E-05 | 2.90965915673149 |
| sigmoid | torch.bfloat16 | 33554432 | 4.87E-05 | 5.24E-05 | 7.63503884668234 |
| sigmoid | torch.bfloat16 | 67108864 | 9.33E-05 | 0.000100019678939134 | 7.21238137428013 |
| sigmoid | torch.bfloat16 | 134217728 | 0.000180786165098349 | 0.000194868014659733 | 7.78922964250206 |
| sigmoid | torch.bfloat16 | 268435456 | 0.000355564659306159 | 0.000384909333661199 | 8.25297835063321 |
| sigmoid | torch.bfloat16 | 536870912 | 0.000705831005082776 | 0.000764102345177283 | 8.2557070566308 |
| sigmoid | torch.float32 | 16777216 | 4.93E-05 | 5.65E-05 | 14.5314136197766 |
| sigmoid | torch.float32 | 33554432 | 9.32E-05 | 9.31E-05 | -0.120169865610833 |
| sigmoid | torch.float32 | 67108864 | 0.000181328505277634 | 0.000180455681402236 | -0.481349512069855 |
| sigmoid | torch.float32 | 134217728 | 0.000357362829769651 | 0.000356093340087682 | -0.35523831137877 |
| sigmoid | torch.float32 | 268435456 | 0.000708921831877281 | 0.000707052337626616 | -0.263709504574663 |
| sigmoid | torch.float32 | 536870912 | 0.00141358317341656 | 0.0014090768333214 | -0.318788464654745 |
| tanh | torch.float16 | 16777216 | 3.03E-05 | 3.03E-05 | -0.0912564658661808 |
| tanh | torch.float16 | 33554432 | 4.90E-05 | 5.07E-05 | 3.46644442974484 |
| tanh | torch.float16 | 67108864 | 9.30E-05 | 9.68E-05 | 3.99871369815531 |
| tanh | torch.float16 | 134217728 | 0.00018052199933057 | 0.000188717152923346 | 4.53969799978138 |
| tanh | torch.float16 | 268435456 | 0.000355684508879979 | 0.000373026006855071 | 4.8755280430115 |
| tanh | torch.float16 | 536870912 | 0.000706660988119741 | 0.000740105014604827 | 4.73268328765002 |
| tanh | torch.bfloat16 | 16777216 | 2.99E-05 | 3.03E-05 | 1.21049563135981 |
| tanh | torch.bfloat16 | 33554432 | 4.89E-05 | 5.06E-05 | 3.48836101041744 |
| tanh | torch.bfloat16 | 67108864 | 9.28E-05 | 9.69E-05 | 4.39944918036626 |
| tanh | torch.bfloat16 | 134217728 | 0.000180710999605556 | 0.000189167990659674 | 4.67984299382829 |
| tanh | torch.bfloat16 | 268435456 | 0.000356062994493792 | 0.000372666652159144 | 4.66312363882606 |
| tanh | torch.bfloat16 | 536870912 | 0.000707100164921333 | 0.000740134331863374 | 4.67178040408393 |
| tanh | torch.float32 | 16777216 | 5.61E-05 | 5.64E-05 | 0.439595755746353 |
| tanh | torch.float32 | 33554432 | 9.31E-05 | 9.31E-05 | 0.00287633090228212 |
| tanh | torch.float32 | 67108864 | 0.000181465332085888 | 0.000180895323865116 | -0.31411411437098 |
| tanh | torch.float32 | 134217728 | 0.000356963835656643 | 0.000356073161431899 | -0.249513854283251 |
| tanh | torch.float32 | 268435456 | 0.000709201170442005 | 0.00070707315656667 | -0.300057862849997 |
| tanh | torch.float32 | 536870912 | 0.00141367283261692 | 0.00141030051357423 | -0.238550176877922 |
| gelu | torch.float16 | 16777216 | 2.73E-05 | 3.17E-05 | 15.921079070745 |
| gelu | torch.float16 | 33554432 | 5.06E-05 | 5.55E-05 | 9.76345374333098 |
| gelu | torch.float16 | 67108864 | 9.65E-05 | 0.000106600326641152 | 10.4308039074712 |
| gelu | torch.float16 | 134217728 | 0.000187776672343413 | 0.000208565829476962 | 11.0712139447915 |
| gelu | torch.float16 | 268435456 | 0.000370216167842348 | 0.000412251994324227 | 11.3544005187205 |
| gelu | torch.float16 | 536870912 | 0.000737301345604161 | 0.000819394170927505 | 11.1342296895002 |
| gelu | torch.bfloat16 | 16777216 | 3.02E-05 | 3.08E-05 | 1.78405479367653 |
| gelu | torch.bfloat16 | 33554432 | 5.13E-05 | 5.69E-05 | 10.9929393318302 |
| gelu | torch.bfloat16 | 67108864 | 9.76E-05 | 0.00010968199543034 | 12.3420807512356 |
| gelu | torch.bfloat16 | 134217728 | 0.000189661824454864 | 0.000214487663470209 | 13.0895287371091 |
| gelu | torch.bfloat16 | 268435456 | 0.000374197009174774 | 0.000423670164309442 | 13.2211519391275 |
| gelu | torch.bfloat16 | 536870912 | 0.000743675006863972 | 0.000842577001700799 | 13.299088166737 |
| gelu | torch.float32 | 16777216 | 5.06E-05 | 5.04E-05 | -0.413385894716413 |
| gelu | torch.float32 | 33554432 | 9.31E-05 | 9.32E-05 | 0.134157041722546 |
| gelu | torch.float32 | 67108864 | 0.000181480175039421 | 0.000180836669945469 | -0.354586992112075 |
| gelu | torch.float32 | 134217728 | 0.000356874331676712 | 0.000356305002545317 | -0.159532104402047 |
| gelu | torch.float32 | 268435456 | 0.000708909006789327 | 0.000706991491218408 | -0.270488250615287 |
| gelu | torch.float32 | 536870912 | 0.00141321367118508 | 0.00140937082081412 | -0.271922813181618 |
| sin | torch.float16 | 16777216 | 3.04E-05 | 3.11E-05 | 2.21834939018859 |
| sin | torch.float16 | 33554432 | 4.85E-05 | 5.23E-05 | 7.72165512511596 |
| sin | torch.float16 | 67108864 | 9.31E-05 | 9.98E-05 | 7.24947099480072 |
| sin | torch.float16 | 134217728 | 0.000180371008658161 | 0.000194791161144773 | 7.99471744039613 |
| sin | torch.float16 | 268435456 | 0.000355454161763191 | 0.000384903668115536 | 8.28503630574026 |
| sin | torch.float16 | 536870912 | 0.000705183832906187 | 0.000764360166310022 | 8.39161799270973 |
| sin | torch.bfloat16 | 16777216 | 3.11E-05 | 3.10E-05 | -0.257677954940036 |
| sin | torch.bfloat16 | 33554432 | 4.89E-05 | 5.24E-05 | 7.34808420323539 |
| sin | torch.bfloat16 | 67108864 | 9.26E-05 | 0.000100248667877167 | 8.22347488801205 |
| sin | torch.bfloat16 | 134217728 | 0.000180674154156198 | 0.00019567032965521 | 8.30012215584937 |
| sin | torch.bfloat16 | 268435456 | 0.000355360486234228 | 0.000386023331278314 | 8.62865913118873 |
| sin | torch.bfloat16 | 536870912 | 0.00070483615854755 | 0.000766805159704139 | 8.79197248964745 |
| sin | torch.float32 | 16777216 | 5.67E-05 | 5.64E-05 | -0.441348534920039 |
| sin | torch.float32 | 33554432 | 9.34E-05 | 9.30E-05 | -0.496458540364117 |
| sin | torch.float32 | 67108864 | 0.000181706990891447 | 0.000180556671693921 | -0.633062708199702 |
| sin | torch.float32 | 134217728 | 0.000356894995396336 | 0.000356046327700218 | -0.237791985616354 |
| sin | torch.float32 | 268435456 | 0.000708777321657787 | 0.000707602652255446 | -0.165731798471427 |
| sin | torch.float32 | 536870912 | 0.00141263716310884 | 0.00140912582476934 | -0.248566187496451 |
| exp | torch.float16 | 16777216 | 3.00E-05 | 3.04E-05 | 1.40099098901014 |
| exp | torch.float16 | 33554432 | 4.86E-05 | 5.03E-05 | 3.44611943643906 |
| exp | torch.float16 | 67108864 | 9.37E-05 | 9.55E-05 | 1.96412400380129 |
| exp | torch.float16 | 134217728 | 0.000180913504057874 | 0.000187193179347863 | 3.47109262113439 |
| exp | torch.float16 | 268435456 | 0.00035607748820136 | 0.000369079003576189 | 3.65131630210701 |
| exp | torch.float16 | 536870912 | 0.000707551507124056 | 0.000732363162872692 | 3.50669251620789 |
| exp | torch.bfloat16 | 16777216 | 2.98E-05 | 3.04E-05 | 1.74345594341654 |
| exp | torch.bfloat16 | 33554432 | 4.88E-05 | 5.04E-05 | 3.40217856534821 |
| exp | torch.bfloat16 | 67108864 | 9.32E-05 | 9.62E-05 | 3.29219958210226 |
| exp | torch.bfloat16 | 134217728 | 0.000180999826019009 | 0.000187239318620414 | 3.44723679499521 |
| exp | torch.bfloat16 | 268435456 | 0.000355944503098726 | 0.000369370992605885 | 3.77207384585864 |
| exp | torch.bfloat16 | 536870912 | 0.000707135167128096 | 0.000733066000975668 | 3.66702648277075 |
| exp | torch.float32 | 16777216 | 4.89E-05 | 5.63E-05 | 15.1245314346532 |
| exp | torch.float32 | 33554432 | 9.34E-05 | 9.31E-05 | -0.259945454477446 |
| exp | torch.float32 | 67108864 | 0.000181152504713585 | 0.000180474346658836 | -0.374357536939058 |
| exp | torch.float32 | 134217728 | 0.000356771342922002 | 0.000355627329554409 | -0.3206573034212 |
| exp | torch.float32 | 268435456 | 0.000708404501589636 | 0.00070713268360123 | -0.179532736671163 |
| exp | torch.float32 | 536870912 | 0.00141283582585553 | 0.00140944866385932 | -0.23974208002295 |
</details>
Pull Request resolved: #145746
Approved by: https://github.com/eqy, https://github.com/ngimel
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
1 parent eeb5e1b commit e84bf88
File tree
8 files changed
+77
-21
lines changed
- aten/src/ATen
- native/cuda
- test
8 files changed
+77
-21
lines changed
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
49 | 49 | | |
50 | 50 | | |
51 | 51 | | |
52 | | - | |
53 | 52 | | |
| 53 | + | |
54 | 54 | | |
55 | 55 | | |
56 | 56 | | |
| |||
131 | 131 | | |
132 | 132 | | |
133 | 133 | | |
| 134 | + | |
| 135 | + | |
| 136 | + | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
| 144 | + | |
| 145 | + | |
134 | 146 | | |
135 | 147 | | |
136 | 148 | | |
137 | 149 | | |
138 | 150 | | |
139 | 151 | | |
140 | 152 | | |
141 | | - | |
142 | | - | |
143 | 153 | | |
144 | 154 | | |
145 | | - | |
| 155 | + | |
| 156 | + | |
| 157 | + | |
146 | 158 | | |
147 | 159 | | |
148 | 160 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
61 | 61 | | |
62 | 62 | | |
63 | 63 | | |
| 64 | + | |
64 | 65 | | |
65 | 66 | | |
66 | 67 | | |
| |||
71 | 72 | | |
72 | 73 | | |
73 | 74 | | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
74 | 85 | | |
75 | 86 | | |
76 | 87 | | |
| |||
191 | 202 | | |
192 | 203 | | |
193 | 204 | | |
| 205 | + | |
194 | 206 | | |
195 | | - | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
196 | 219 | | |
197 | 220 | | |
198 | 221 | | |
199 | 222 | | |
200 | 223 | | |
201 | 224 | | |
202 | 225 | | |
| 226 | + | |
203 | 227 | | |
204 | 228 | | |
205 | 229 | | |
206 | 230 | | |
207 | 231 | | |
208 | | - | |
209 | 232 | | |
210 | 233 | | |
211 | 234 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
217 | 217 | | |
218 | 218 | | |
219 | 219 | | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
220 | 223 | | |
221 | | - | |
| 224 | + | |
222 | 225 | | |
223 | 226 | | |
224 | 227 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
351 | 351 | | |
352 | 352 | | |
353 | 353 | | |
354 | | - | |
355 | 354 | | |
| 355 | + | |
356 | 356 | | |
357 | 357 | | |
358 | 358 | | |
359 | 359 | | |
360 | 360 | | |
361 | 361 | | |
362 | 362 | | |
| 363 | + | |
| 364 | + | |
| 365 | + | |
| 366 | + | |
363 | 367 | | |
364 | 368 | | |
365 | 369 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
932 | 932 | | |
933 | 933 | | |
934 | 934 | | |
935 | | - | |
936 | 935 | | |
937 | 936 | | |
938 | 937 | | |
| |||
952 | 951 | | |
953 | 952 | | |
954 | 953 | | |
955 | | - | |
956 | 954 | | |
957 | 955 | | |
958 | 956 | | |
| |||
971 | 969 | | |
972 | 970 | | |
973 | 971 | | |
974 | | - | |
| 972 | + | |
| 973 | + | |
| 974 | + | |
| 975 | + | |
| 976 | + | |
| 977 | + | |
| 978 | + | |
| 979 | + | |
975 | 980 | | |
976 | 981 | | |
977 | 982 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
60 | 60 | | |
61 | 61 | | |
62 | 62 | | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
63 | 67 | | |
64 | 68 | | |
65 | 69 | | |
| |||
88 | 92 | | |
89 | 93 | | |
90 | 94 | | |
| 95 | + | |
91 | 96 | | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
92 | 100 | | |
93 | | - | |
94 | 101 | | |
95 | 102 | | |
96 | 103 | | |
97 | 104 | | |
98 | 105 | | |
99 | | - | |
100 | 106 | | |
101 | 107 | | |
102 | 108 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
| 15 | + | |
| 16 | + | |
15 | 17 | | |
16 | 18 | | |
17 | 19 | | |
18 | 20 | | |
| 21 | + | |
| 22 | + | |
19 | 23 | | |
20 | 24 | | |
21 | | - | |
22 | 25 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
47 | 47 | | |
48 | 48 | | |
49 | 49 | | |
50 | | - | |
51 | | - | |
52 | | - | |
53 | | - | |
54 | | - | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
55 | 55 | | |
56 | 56 | | |
57 | 57 | | |
| |||
65 | 65 | | |
66 | 66 | | |
67 | 67 | | |
68 | | - | |
69 | | - | |
| 68 | + | |
| 69 | + | |
70 | 70 | | |
71 | 71 | | |
72 | 72 | | |
| |||
0 commit comments