Make TensorIterator stop promoting types by copying #28344
Conversation
Fixes: #26401

This PR fixes the issue by using the newly added dynamic cast inside `TensorIterator`: instead of converting the operands to the promoted type up front (which launches extra kernels), `TensorIterator` now does a load-cast-compute-store for each element while looping, so there is only one read and one write of memory per element.

**nvprof:**

```python
import torch

_100M = 100 * 1024 ** 2
r = torch.randn(_100M, dtype=torch.float32, device='cuda')
d = torch.randn(_100M, dtype=torch.float64, device='cuda')
torch.cuda.synchronize()

torch.cuda.profiler.start()
r.add_(d)
torch.cuda.profiler.stop()
torch.cuda.synchronize()
```

```
==11407== NVPROF is profiling process 11407, command: /home/xgao/anaconda3/bin/python simple.py
==11407== Profiling application: /home/xgao/anaconda3/bin/python simple.py
==11407== Profiling result:
            Type  Time(%)      Time  Calls       Avg       Min       Max  Name
 GPU activities:  100.00%  2.0611ms      1  2.0611ms  2.0611ms  2.0611ms  _ZN2at6native18elementwise_kernelILi512ELi1EZNS0_15gpu_kernel_implIZZZNS0_15add_kernel_cudaERNS_14TensorIteratorEN3c106ScalarEENKUlvE_clEvENKUlvE1_clEvEUlddE_EEvS4_RKT_EUliE_EEviT1_
      API calls:  100.00%  1.05006s      1  1.05006s  1.05006s  1.05006s  cudaLaunchKernel
                    0.00%  2.7740us      2  1.3870us     673ns  2.1010us  cudaGetDevice
                    0.00%  2.3730us      1  2.3730us  2.3730us  2.3730us  cudaSetDevice
                    0.00%     830ns      1     830ns     830ns     830ns  cudaGetLastError
```

**benchmark**

```python
import torch
print(torch.__version__)
print(torch.version.git_version)

_100M = 100 * 1024 ** 2
r = torch.randn(_100M, dtype=torch.float32, device='cuda')
d = torch.randn(_100M, dtype=torch.float64, device='cuda')
torch.cuda.synchronize()

%timeit r.add_(d); torch.cuda.synchronize()
```

original

```
1.4.0a0+7d277b0
7d277b0
6.83 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

after

```
1.4.0a0+f0f2f65
f0f2f65
2.08 ms ± 139 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

ghstack-source-id: 1153268
Pull Request resolved: #28344
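For intuition, here is a minimal, self-contained sketch of the load-cast-compute-store idea for a mixed-dtype in-place add. It is illustrative only, not the actual `TensorIterator` code; the `Dtype` enum and the `load_as_double`/`store_from_double` helpers are made up for this example.

```cpp
#include <cstdint>
#include <cstring>

enum class Dtype { Float, Double };

// Load the i-th element (given its byte stride) and cast it to the
// computation type (double in this toy example) in a register.
static double load_as_double(const char* data, int64_t stride, int64_t i, Dtype dtype) {
  const char* p = data + i * stride;
  switch (dtype) {
    case Dtype::Float:  { float v;  std::memcpy(&v, p, sizeof(v)); return v; }
    case Dtype::Double: { double v; std::memcpy(&v, p, sizeof(v)); return v; }
  }
  return 0.0;
}

// Cast the computed value back to the output dtype and store it.
static void store_from_double(char* data, int64_t stride, int64_t i, Dtype dtype, double v) {
  char* p = data + i * stride;
  switch (dtype) {
    case Dtype::Float:  { float f = static_cast<float>(v); std::memcpy(p, &f, sizeof(f)); break; }
    case Dtype::Double: { std::memcpy(p, &v, sizeof(v)); break; }
  }
}

// out += src with possibly different dtypes, in a single pass over memory:
// each element of out and src is read once and the result written once.
void add_inplace_with_cast(char* out, Dtype out_dtype, int64_t out_stride,
                           const char* src, Dtype src_dtype, int64_t src_stride,
                           int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    double x = load_as_double(out, out_stride, i, out_dtype);
    double y = load_as_double(src, src_stride, i, src_dtype);
    store_from_double(out, out_stride, i, out_dtype, x + y);
  }
}
```

The point is that the `float32` destination and the `float64` source are each touched exactly once per element; the casts happen in registers rather than through a temporary converted tensor.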
    constexpr int ntensors = traits::arity + 1;

    // Copying strides to temporary array helps auto vectorization in older GCC
    // versions.
which gcc versions need this? Note that gcc 5 is no longer supported, so workarounds for it are not necessary.
I don't know. It was copy-pasted from the existing code and modified. Let me try to find the answer.
I searched gcc's changelogs (for example https://gcc.gnu.org/gcc-7/changes.html) for "vectoriz" across different gcc versions. The only thing that looks relevant is at https://gcc.gnu.org/gcc-5/changes.html, but I don't know whether it is related.
@colesbury might know the answer.
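For reference, a minimal sketch of the pattern the comment describes, with illustrative names (`ntensors`, `elementwise_loop`) rather than the PR's actual code: the strides are copied into a fixed-size local array before the hot loop, so the compiler sees a compile-time bound, can assume the values do not alias the data being written, and has an easier time auto-vectorizing.

```cpp
#include <array>
#include <cstdint>

// Illustrative only: copy the runtime stride array into a local array whose
// size is known at compile time before entering the inner loop.
template <int ntensors>
void elementwise_loop(char** data, const int64_t* strides, int64_t n) {
  std::array<char*, ntensors> ptrs;
  std::array<int64_t, ntensors> local_strides;
  for (int arg = 0; arg < ntensors; ++arg) {
    ptrs[arg] = data[arg];
    local_strides[arg] = strides[arg];
  }
  for (int64_t i = 0; i < n; ++i) {
    // ... load from ptrs[1..], compute, store to ptrs[0] ...
    for (int arg = 0; arg < ntensors; ++arg) {
      ptrs[arg] += local_strides[arg];  // advance each operand by its stride
    }
  }
}
```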
Please also include CPU benchmarks with and without type promotions.
@VitalyFedyunin There is very little change in performance. The benchmark is as follows:

    import torch
    print(torch.__version__)
    print(torch.version.git_version)

    _100M = 100 * 1024 ** 2
    r = torch.randn(_100M, dtype=torch.float32, device='cpu')
    d = torch.randn(_100M, dtype=torch.float64, device='cpu')

    %timeit r.add_(d);

before

after
Without promotion on CPU:

    import torch
    print(torch.__version__)
    print(torch.version.git_version)

    _100M = 100 * 1024 ** 2
    a = torch.randn(_100M, dtype=torch.float32, device='cpu')
    b = torch.randn(_100M, dtype=torch.float32, device='cpu')

    %timeit a.add_(b);

before

after
Without promotion on GPU:

    import torch
    print(torch.__version__)
    print(torch.version.git_version)

    _100M = 100 * 1024 ** 2
    a = torch.randn(_100M, dtype=torch.float32, device='cuda')
    b = torch.randn(_100M, dtype=torch.float32, device='cuda')
    torch.cuda.synchronize()

    %timeit a.add_(b); torch.cuda.synchronize()

before

after
The case where there is no promotion is dispatched to the original code path at https://github.com/pytorch/pytorch/pull/28344/files#diff-0d1178f1a4ce15aeb760d251974e6924R242
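In other words, the dispatch is roughly shaped like the toy sketch below (the names `needs_dynamic_casting`, `original_loop`, and `casting_loop` are hypothetical, not the code at the link): when no operand needs casting, the pre-existing loop runs unchanged, which is why the no-promotion benchmarks above are expected to be unaffected.

```cpp
#include <functional>
#include <typeindex>
#include <vector>

// Toy model of the dispatch: take the per-element casting path only when
// some operand's dtype differs from the common computation dtype.
struct ToyIter {
  std::vector<std::type_index> operand_dtypes;
  std::type_index common_dtype;
};

static bool needs_dynamic_casting(const ToyIter& iter) {
  for (const auto& dtype : iter.operand_dtypes) {
    if (dtype != iter.common_dtype) return true;
  }
  return false;
}

static void run_kernel(const ToyIter& iter,
                       const std::function<void()>& original_loop,
                       const std::function<void()>& casting_loop) {
  if (needs_dynamic_casting(iter)) {
    casting_loop();   // load-cast-compute-store per element (this PR)
  } else {
    original_loop();  // unchanged fast path; no overhead when dtypes match
  }
}
```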
I messed up the PRs and they were merged by ghstack... Will resubmit soon.
Summary: Pull Request resolved: #28427

Fixes: #26401

Test Plan: Imported from OSS

Differential Revision: D18170997

Pulled By: ezyang

fbshipit-source-id: 9c82c1c89583f3e6202c5d790b9b73ad9f960fad
Stack from ghstack:
Fixes: #26401