Skip to content

Commit fb93c39

Browse files
suo authored and pytorchmergebot committed
[build] Split .cu to improve compile times (#81193)
The goal is to speed up CUDA builds. I was looking at bulid times and found that we have large CUDA compilation units that take forever to compile and make parallelism less effective. This PR splits them up into different `.cu` files so we can parallelize compilation better. We've done this sort of thing in the past with some success. With a cold build, timing before: 5m42.019s, timing after: 4m30.275s. That's a speedup of 18.1% for me. Behaviorally this should be a no-op, I'm just moving code around. There is still more we can do here but I did most of the ones that are copypasta. The full list of remaining chonky compilation units is [here](https://gist.github.com/suo/0dc217733f40f59898a8cc4f60529d60). ## Details Here's a screenshot from a ninja trace, with the following command: ``` MAX_JOBS=64 CCACHE_DISABLE=1 TORCH_CUDA_ARCH_LIST=Ampere BUILD_CAFFE2_OPS=0 USE_FBGEMM=0 USE_DISTRIBUTED=0 USE_MKLDNN=0 BUILD_TEST=0 USE_GOLD_LINKER=1 USE_OPENMP=1 USE_NCCL=0 DEBUG=0 python setup.py develop ``` <img width="1475" alt="image" src="https://user-images.githubusercontent.com/1617424/178170276-ee0e5eb0-4c16-4b86-b4af-2a9e615b7f5f.png"> ([source trace](https://gist.github.com/suo/5f5458f2630f9ab6dcbea6989e892195), which you can visualize in [perfetto](https://ui.perfetto.dev/)) After this PR, we get somewhat better utilization (although there is plenty still left to do): <img width="1466" alt="image" src="https://user-images.githubusercontent.com/1617424/178178944-63ca9ff0-9cd3-4008-9a6d-d8623b5148c5.png"> ([source trace](https://gist.github.com/suo/5607335bcd4bd412d42b0c9334259184)) Pull Request resolved: #81193 Approved by: https://github.com/cpuhrsch, https://github.com/malfet
1 parent 282de55 commit fb93c39

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+2952
-1914
lines changed
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
#define TORCH_ASSERT_NO_OPERATORS
2+
#include <ATen/cuda/CUDAConfig.h>
3+
#include <ATen/cuda/cub.cuh>
4+
5+
namespace at {
6+
namespace cuda {
7+
namespace cub {
8+
9+
template <typename key_t>
10+
void radix_sort_keys(
11+
const key_t* keys_in,
12+
key_t* keys_out,
13+
int64_t n,
14+
bool descending,
15+
int64_t begin_bit,
16+
int64_t end_bit) {
17+
TORCH_CHECK(
18+
n <= std::numeric_limits<int>::max(),
19+
"cub sort does not support sorting more than INT_MAX elements");
20+
using key_t_ = typename detail::cuda_type<key_t>::type;
21+
22+
const key_t_* keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
23+
key_t_* keys_out_ = reinterpret_cast<key_t_*>(keys_out);
24+
25+
if (descending) {
26+
CUB_WRAPPER(
27+
NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeysDescending,
28+
keys_in_,
29+
keys_out_,
30+
n,
31+
begin_bit,
32+
end_bit,
33+
c10::cuda::getCurrentCUDAStream());
34+
} else {
35+
CUB_WRAPPER(
36+
NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeys,
37+
keys_in_,
38+
keys_out_,
39+
n,
40+
begin_bit,
41+
end_bit,
42+
c10::cuda::getCurrentCUDAStream());
43+
}
44+
}
45+
46+
template <typename scalar_t>
47+
void unique(
48+
const scalar_t* input,
49+
scalar_t* output,
50+
int64_t* num_selected_out,
51+
int64_t num_items) {
52+
TORCH_CHECK(
53+
num_items <= std::numeric_limits<int>::max(),
54+
"cub unique does not support more than INT_MAX elements");
55+
CUB_WRAPPER(
56+
NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique,
57+
input,
58+
output,
59+
num_selected_out,
60+
num_items,
61+
at::cuda::getCurrentCUDAStream());
62+
}
63+
64+
template <typename scalar_t>
65+
void run_length_encode(
66+
const scalar_t* input,
67+
scalar_t* output,
68+
int64_t* counts_out,
69+
int64_t* length_out,
70+
int64_t num_items) {
71+
TORCH_CHECK(
72+
num_items <= std::numeric_limits<int>::max(),
73+
"cub run_length_encode does not support more than INT_MAX elements");
74+
CUB_WRAPPER(
75+
NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode,
76+
input,
77+
output,
78+
counts_out,
79+
length_out,
80+
num_items,
81+
at::cuda::getCurrentCUDAStream());
82+
}
83+
84+
#define AT_INSTATIATE_CUB_TEMPLATES(scalar_t, ScalarType) \
85+
template void radix_sort_keys( \
86+
const scalar_t* keys_in, \
87+
scalar_t* keys_out, \
88+
int64_t n, \
89+
bool descending, \
90+
int64_t begin_bit, \
91+
int64_t end_bit); \
92+
template void unique( \
93+
const scalar_t* input, \
94+
scalar_t* output, \
95+
int64_t* num_selected_out, \
96+
int64_t num_items); \
97+
template void run_length_encode( \
98+
const scalar_t* input, \
99+
scalar_t* output, \
100+
int64_t* counts_out, \
101+
int64_t* length_out, \
102+
int64_t n);
103+
104+
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTATIATE_CUB_TEMPLATES)
105+
106+
} // namespace cub
107+
} // namespace cuda
108+
} // namespace at
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#define TORCH_ASSERT_NO_OPERATORS
2+
#include <ATen/cuda/CUDAConfig.h>
3+
#include <ATen/cuda/cub.cuh>
4+
5+
namespace at {
6+
namespace cuda {
7+
namespace cub {
8+
namespace detail {
9+
10+
template <typename key_t, int value_size>
11+
void radix_sort_pairs_impl(
12+
const key_t* keys_in,
13+
key_t* keys_out,
14+
const OpaqueType<value_size>* values_in,
15+
OpaqueType<value_size>* values_out,
16+
int64_t n,
17+
bool descending,
18+
int64_t begin_bit,
19+
int64_t end_bit) {
20+
TORCH_CHECK(
21+
n <= std::numeric_limits<int>::max(),
22+
"cub sort does not support sorting more than INT_MAX elements");
23+
using key_t_ = typename detail::cuda_type<key_t>::type;
24+
25+
auto allocator = c10::cuda::CUDACachingAllocator::get();
26+
c10::DataPtr keys_out_owner;
27+
28+
if (keys_out == nullptr) {
29+
keys_out_owner = allocator->allocate(n * sizeof(key_t));
30+
keys_out = reinterpret_cast<key_t*>(keys_out_owner.get());
31+
}
32+
33+
const key_t_* keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
34+
key_t_* keys_out_ = reinterpret_cast<key_t_*>(keys_out);
35+
36+
if (descending) {
37+
CUB_WRAPPER(
38+
NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairsDescending,
39+
keys_in_,
40+
keys_out_,
41+
values_in,
42+
values_out,
43+
n,
44+
begin_bit,
45+
end_bit,
46+
c10::cuda::getCurrentCUDAStream());
47+
} else {
48+
CUB_WRAPPER(
49+
NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairs,
50+
keys_in_,
51+
keys_out_,
52+
values_in,
53+
values_out,
54+
n,
55+
begin_bit,
56+
end_bit,
57+
c10::cuda::getCurrentCUDAStream());
58+
}
59+
}
60+
61+
#define AT_INSTANTIATE_SORT_PAIRS(key_t, value_size) \
62+
template void radix_sort_pairs_impl( \
63+
const key_t* keys_in, \
64+
key_t* keys_out, \
65+
const OpaqueType<value_size>* values_in, \
66+
OpaqueType<value_size>* values_out, \
67+
int64_t n, \
68+
bool descending, \
69+
int64_t begin_bit, \
70+
int64_t end_bit);
71+
72+
AT_INSTANTIATE_SORT_PAIRS(int32_t, 1)
73+
AT_INSTANTIATE_SORT_PAIRS(int32_t, 2)
74+
AT_INSTANTIATE_SORT_PAIRS(int32_t, 4)
75+
AT_INSTANTIATE_SORT_PAIRS(int64_t, 1)
76+
AT_INSTANTIATE_SORT_PAIRS(int64_t, 2)
77+
AT_INSTANTIATE_SORT_PAIRS(int64_t, 4)
78+
79+
#define AT_INSTANTIATE_SORT_PAIRS_8(scalar_t, ScalarType) \
80+
AT_INSTANTIATE_SORT_PAIRS(scalar_t, 8)
81+
82+
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8)
83+
84+
// BFloat16 Radix sort is supported from ROCm 4.5 onwards
85+
#if !AT_ROCM_ENABLED() || (AT_ROCM_ENABLED() && ROCM_VERSION >= 40500)
86+
AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8)
87+
#endif
88+
89+
} // namespace detail
90+
91+
} // namespace cub
92+
} // namespace cuda
93+
} // namespace at

aten/src/ATen/cuda/cub.cu

Lines changed: 0 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -5,118 +5,6 @@
55
namespace at {
66
namespace cuda {
77
namespace cub {
8-
namespace detail {
9-
10-
template<typename key_t, int value_size>
11-
void radix_sort_pairs_impl(
12-
const key_t *keys_in, key_t *keys_out,
13-
const OpaqueType<value_size> *values_in, OpaqueType<value_size> *values_out,
14-
int64_t n, bool descending, int64_t begin_bit, int64_t end_bit) {
15-
TORCH_CHECK(n <= std::numeric_limits<int>::max(),
16-
"cub sort does not support sorting more than INT_MAX elements");
17-
using key_t_ = typename detail::cuda_type<key_t>::type;
18-
19-
auto allocator = c10::cuda::CUDACachingAllocator::get();
20-
c10::DataPtr keys_out_owner;
21-
22-
if (keys_out == nullptr) {
23-
keys_out_owner = allocator->allocate(n * sizeof(key_t));
24-
keys_out = reinterpret_cast<key_t *>(keys_out_owner.get());
25-
}
26-
27-
const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
28-
key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);
29-
30-
if (descending) {
31-
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairsDescending,
32-
keys_in_, keys_out_, values_in, values_out, n,
33-
begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
34-
} else {
35-
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortPairs,
36-
keys_in_, keys_out_, values_in, values_out, n,
37-
begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
38-
}
39-
}
40-
41-
#define AT_INSTANTIATE_SORT_PAIRS(key_t, value_size) \
42-
template void radix_sort_pairs_impl( \
43-
const key_t *keys_in, key_t *keys_out, \
44-
const OpaqueType<value_size> *values_in, \
45-
OpaqueType<value_size> *values_out, \
46-
int64_t n, bool descending, int64_t begin_bit, int64_t end_bit);
47-
48-
AT_INSTANTIATE_SORT_PAIRS(int32_t, 1)
49-
AT_INSTANTIATE_SORT_PAIRS(int32_t, 2)
50-
AT_INSTANTIATE_SORT_PAIRS(int32_t, 4)
51-
AT_INSTANTIATE_SORT_PAIRS(int64_t, 1)
52-
AT_INSTANTIATE_SORT_PAIRS(int64_t, 2)
53-
AT_INSTANTIATE_SORT_PAIRS(int64_t, 4)
54-
55-
#define AT_INSTANTIATE_SORT_PAIRS_8(scalar_t, ScalarType) \
56-
AT_INSTANTIATE_SORT_PAIRS(scalar_t, 8)
57-
58-
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTANTIATE_SORT_PAIRS_8)
59-
60-
// BFloat16 Radix sort is supported from ROCm 4.5 onwards
61-
#if !AT_ROCM_ENABLED() || (AT_ROCM_ENABLED() && ROCM_VERSION >= 40500)
62-
AT_INSTANTIATE_SORT_PAIRS(c10::BFloat16, 8)
63-
#endif
64-
65-
} // namespace detail
66-
67-
template<typename key_t>
68-
void radix_sort_keys(
69-
const key_t *keys_in, key_t *keys_out,
70-
int64_t n, bool descending, int64_t begin_bit, int64_t end_bit) {
71-
TORCH_CHECK(n <= std::numeric_limits<int>::max(),
72-
"cub sort does not support sorting more than INT_MAX elements");
73-
using key_t_ = typename detail::cuda_type<key_t>::type;
74-
75-
const key_t_ *keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
76-
key_t_ *keys_out_ = reinterpret_cast<key_t_*>(keys_out);
77-
78-
if (descending) {
79-
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeysDescending,
80-
keys_in_, keys_out_, n,
81-
begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
82-
} else {
83-
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceRadixSort::SortKeys,
84-
keys_in_, keys_out_, n,
85-
begin_bit, end_bit, c10::cuda::getCurrentCUDAStream());
86-
}
87-
}
88-
89-
template<typename scalar_t>
90-
void unique(const scalar_t *input, scalar_t *output, int64_t *num_selected_out, int64_t num_items) {
91-
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
92-
"cub unique does not support more than INT_MAX elements");
93-
CUB_WRAPPER(NO_ROCM(at_cuda_detail)::cub::DeviceSelect::Unique,
94-
input, output, num_selected_out, num_items, at::cuda::getCurrentCUDAStream());
95-
}
96-
97-
template <typename scalar_t>
98-
void run_length_encode(const scalar_t *input, scalar_t *output, int64_t *counts_out,
99-
int64_t *length_out, int64_t num_items) {
100-
TORCH_CHECK(num_items <= std::numeric_limits<int>::max(),
101-
"cub run_length_encode does not support more than INT_MAX elements");
102-
CUB_WRAPPER(
103-
NO_ROCM(at_cuda_detail)::cub::DeviceRunLengthEncode::Encode,
104-
input, output, counts_out, length_out, num_items,
105-
at::cuda::getCurrentCUDAStream());
106-
}
107-
108-
#define AT_INSTATIATE_CUB_TEMPLATES(scalar_t, ScalarType) \
109-
template void radix_sort_keys( \
110-
const scalar_t *keys_in, scalar_t *keys_out, int64_t n, \
111-
bool descending, int64_t begin_bit, int64_t end_bit); \
112-
template void unique( \
113-
const scalar_t *input, scalar_t *output, \
114-
int64_t *num_selected_out, int64_t num_items); \
115-
template void run_length_encode( \
116-
const scalar_t *input, scalar_t *output, int64_t *counts_out, \
117-
int64_t *length_out, int64_t n);
118-
119-
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, AT_INSTATIATE_CUB_TEMPLATES)
1208

1219
namespace {
12210
template <typename scalar_t>

0 commit comments

Comments
 (0)