
Commit d975609

gchanan authored and facebook-github-bot committed
Split BinaryCompareKernel.cu into a file-per-kernel to speed up compilation. (#33871)
Summary:
Pull Request resolved: #33871

Test Plan: Imported from OSS

Differential Revision: D20140862

Pulled By: gchanan

fbshipit-source-id: a4fde38c1c7c5905e3855fa490ea2e87bb24c703
1 parent 5eacdfb commit d975609

File tree

7 files changed: +142 -57 lines
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+#include <ATen/Dispatch.h>
+#include <ATen/native/BinaryOps.h>
+#include <ATen/native/DispatchStub.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cuda/Loops.cuh>
+#include <ATen/native/cuda/zmath.cuh>
+
+
+// NOTE: CUDA on Windows requires that the enclosing function
+// of a __device__ lambda not have internal linkage.
+
+namespace at { namespace native {
+
+void eq_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.common_dtype(), "eq_cuda", [&]() {
+    using thrust_t = typename ztype_cuda<scalar_t>::thrust_t;
+    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(thrust_t a, thrust_t b) -> bool {
+      return a == b;
+    });
+  });
+}
+
+REGISTER_DISPATCH(eq_stub, &eq_kernel_cuda);
+
+}} // namespace at::native
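Context for the NOTE comment and the out-of-line kernel functions above: gpu_kernel_with_scalars (from Loops.cuh) launches the GPU_LAMBDA elementwise over the iterator's operands. Below is a rough, self-contained sketch of that pattern, not the ATen implementation; the names elementwise_compare_kernel and launch_eq_example are hypothetical, it assumes nvcc with extended lambdas enabled (--expt-extended-lambda), and it omits the broadcasting, type promotion, and scalar handling that TensorIterator provides.

#include <cuda_runtime.h>
#include <cstdint>

// Hypothetical sketch: apply a comparison functor elementwise over two
// device arrays and write bool results.
template <typename scalar_t, typename func_t>
__global__ void elementwise_compare_kernel(const scalar_t* a, const scalar_t* b,
                                           bool* out, int64_t n, func_t op) {
  int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = op(a[i], b[i]);
  }
}

// The enclosing function deliberately has external linkage (no `static`,
// not in an anonymous namespace), mirroring the NOTE in the files above
// about __device__ lambdas on Windows.
void launch_eq_example(const float* a, const float* b, bool* out, int64_t n) {
  auto op = [] __device__ (float x, float y) -> bool { return x == y; };
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  elementwise_compare_kernel<<<blocks, threads>>>(a, b, out, n, op);
}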
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+#include <ATen/Dispatch.h>
+#include <ATen/native/BinaryOps.h>
+#include <ATen/native/DispatchStub.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cuda/Loops.cuh>
+
+
+// NOTE: CUDA on Windows requires that the enclosing function
+// of a __device__ lambda not have internal linkage.
+
+namespace at { namespace native {
+
+void ge_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "ge_cuda", [&]() {
+    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
+      return a >= b;
+    });
+  });
+}
+
+REGISTER_DISPATCH(ge_stub, &ge_kernel_cuda);
+
+}} // namespace at::native
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+#include <ATen/Dispatch.h>
+#include <ATen/native/BinaryOps.h>
+#include <ATen/native/DispatchStub.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cuda/Loops.cuh>
+
+
+// NOTE: CUDA on Windows requires that the enclosing function
+// of a __device__ lambda not have internal linkage.
+
+namespace at { namespace native {
+
+void gt_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "gt_cuda", [&]() {
+    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
+      return a > b;
+    });
+  });
+}
+
+REGISTER_DISPATCH(gt_stub, &gt_kernel_cuda);
+
+}} // namespace at::native
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+#include <ATen/Dispatch.h>
+#include <ATen/native/BinaryOps.h>
+#include <ATen/native/DispatchStub.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cuda/Loops.cuh>
+
+
+// NOTE: CUDA on Windows requires that the enclosing function
+// of a __device__ lambda not have internal linkage.
+
+namespace at { namespace native {
+
+void le_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "le_cuda", [&]() {
+    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
+      return a <= b;
+    });
+  });
+}
+
+REGISTER_DISPATCH(le_stub, &le_kernel_cuda);
+
+}} // namespace at::native
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+#include <ATen/Dispatch.h>
+#include <ATen/native/BinaryOps.h>
+#include <ATen/native/DispatchStub.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cuda/Loops.cuh>
+
+
+// NOTE: CUDA on Windows requires that the enclosing function
+// of a __device__ lambda not have internal linkage.
+
+namespace at { namespace native {
+
+void lt_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "lt_cuda", [&]() {
+    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
+      return a < b;
+    });
+  });
+}
+
+REGISTER_DISPATCH(lt_stub, &lt_kernel_cuda);
+
+}} // namespace at::native
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+#include <ATen/Dispatch.h>
+#include <ATen/native/BinaryOps.h>
+#include <ATen/native/DispatchStub.h>
+#include <ATen/native/TensorIterator.h>
+#include <ATen/native/cuda/Loops.cuh>
+#include <ATen/native/cuda/zmath.cuh>
+
+
+// NOTE: CUDA on Windows requires that the enclosing function
+// of a __device__ lambda not have internal linkage.
+
+namespace at { namespace native {
+
+void ne_kernel_cuda(TensorIterator& iter) {
+  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.common_dtype(), "ne_cuda", [&]() {
+    using thrust_t = typename ztype_cuda<scalar_t>::thrust_t;
+    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(thrust_t a, thrust_t b) -> bool {
+      return a != b;
+    });
+  });
+}
+
+REGISTER_DISPATCH(ne_stub, &ne_kernel_cuda);
+
+}} // namespace at::native

aten/src/ATen/native/cuda/BinaryCompareKernel.cu renamed to aten/src/ATen/native/cuda/MaxMinElementwiseKernel.cu

Lines changed: 0 additions & 57 deletions
@@ -11,56 +11,6 @@
 
 namespace at { namespace native {
 
-void lt_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "lt_cuda", [&]() {
-    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
-      return a < b;
-    });
-  });
-}
-
-void le_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "le_cuda", [&]() {
-    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
-      return a <= b;
-    });
-  });
-}
-
-void gt_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "gt_cuda", [&]() {
-    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
-      return a > b;
-    });
-  });
-}
-
-void ge_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "ge_cuda", [&]() {
-    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool {
-      return a >= b;
-    });
-  });
-}
-
-void eq_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.common_dtype(), "eq_cuda", [&]() {
-    using thrust_t = typename ztype_cuda<scalar_t>::thrust_t;
-    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(thrust_t a, thrust_t b) -> bool {
-      return a == b;
-    });
-  });
-}
-
-void ne_kernel_cuda(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.common_dtype(), "ne_cuda", [&]() {
-    using thrust_t = typename ztype_cuda<scalar_t>::thrust_t;
-    gpu_kernel_with_scalars(iter, []GPU_LAMBDA(thrust_t a, thrust_t b) -> bool {
-      return a != b;
-    });
-  });
-}
-
 void max_elementwise_kernel_cuda(TensorIterator& iter) {
   if (iter.dtype() == ScalarType::Bool) {
     gpu_kernel(iter, []GPU_LAMBDA(bool a, bool b) -> bool {
@@ -119,13 +69,6 @@ void min_elementwise_kernel_cuda(TensorIterator& iter) {
   }
 }
 
-
-REGISTER_DISPATCH(lt_stub, &lt_kernel_cuda);
-REGISTER_DISPATCH(le_stub, &le_kernel_cuda);
-REGISTER_DISPATCH(gt_stub, &gt_kernel_cuda);
-REGISTER_DISPATCH(ge_stub, &ge_kernel_cuda);
-REGISTER_DISPATCH(eq_stub, &eq_kernel_cuda);
-REGISTER_DISPATCH(ne_stub, &ne_kernel_cuda);
 REGISTER_DISPATCH(max_elementwise_stub, &max_elementwise_kernel_cuda);
 REGISTER_DISPATCH(min_elementwise_stub, &min_elementwise_kernel_cuda);

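Why moving the REGISTER_DISPATCH lines into separate files is safe: each stub (eq_stub, lt_stub, and so on) is a per-operator dispatch slot declared in a shared header (the new files include ATen/native/BinaryOps.h for it), and each kernel file fills in its slot at static-initialization time, so the split is purely a build-time change. The following is a simplified, self-contained C++ sketch of that registration pattern; all types and names here are hypothetical stand-ins, and the real mechanism in ATen/native/DispatchStub.h also keys on device type and CPU capability.

#include <cstdio>

// Hypothetical stand-ins; not the actual ATen types.
struct TensorIteratorStub {};
using binary_fn = void (*)(TensorIteratorStub&);

// Simplified dispatch stub: a slot that a backend kernel file fills in.
struct DispatchStubSketch {
  binary_fn cuda_impl = nullptr;
  void operator()(TensorIteratorStub& iter) const {
    if (cuda_impl) cuda_impl(iter);  // forward to whatever kernel registered itself
  }
};

DispatchStubSketch eq_stub;  // shared declaration, analogous to the stub in BinaryOps.h

// What a per-kernel .cu file does conceptually: define the kernel and
// register its address during static initialization, analogous to
// REGISTER_DISPATCH(eq_stub, &eq_kernel_cuda).
void eq_kernel_cuda(TensorIteratorStub& iter) { std::puts("eq kernel ran"); }
static const bool eq_registered = (eq_stub.cuda_impl = &eq_kernel_cuda, true);

int main() {
  TensorIteratorStub iter;
  eq_stub(iter);  // operator code calls the stub, never the kernel directly
  return 0;
}

At the call site the operator implementation goes through the stub (using the device type as a key), which is why no callers change when the kernels move between translation units.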