|
11 | 11 |
|
12 | 12 | namespace at { namespace native { |
13 | 13 |
|
14 | | -void lt_kernel_cuda(TensorIterator& iter) { |
15 | | - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "lt_cuda", [&]() { |
16 | | - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { |
17 | | - return a < b; |
18 | | - }); |
19 | | - }); |
20 | | -} |
21 | | - |
22 | | -void le_kernel_cuda(TensorIterator& iter) { |
23 | | - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "le_cuda", [&]() { |
24 | | - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { |
25 | | - return a <= b; |
26 | | - }); |
27 | | - }); |
28 | | -} |
29 | | - |
30 | | -void gt_kernel_cuda(TensorIterator& iter) { |
31 | | - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "gt_cuda", [&]() { |
32 | | - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { |
33 | | - return a > b; |
34 | | - }); |
35 | | - }); |
36 | | -} |
37 | | - |
38 | | -void ge_kernel_cuda(TensorIterator& iter) { |
39 | | - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBool, iter.common_dtype(), "ge_cuda", [&]() { |
40 | | - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(scalar_t a, scalar_t b) -> bool { |
41 | | - return a >= b; |
42 | | - }); |
43 | | - }); |
44 | | -} |
45 | | - |
46 | | -void eq_kernel_cuda(TensorIterator& iter) { |
47 | | - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.common_dtype(), "eq_cuda", [&]() { |
48 | | - using thrust_t = typename ztype_cuda<scalar_t>::thrust_t; |
49 | | - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(thrust_t a, thrust_t b) -> bool { |
50 | | - return a == b; |
51 | | - }); |
52 | | - }); |
53 | | -} |
54 | | - |
55 | | -void ne_kernel_cuda(TensorIterator& iter) { |
56 | | - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kHalf, kBool, iter.common_dtype(), "ne_cuda", [&]() { |
57 | | - using thrust_t = typename ztype_cuda<scalar_t>::thrust_t; |
58 | | - gpu_kernel_with_scalars(iter, []GPU_LAMBDA(thrust_t a, thrust_t b) -> bool { |
59 | | - return a != b; |
60 | | - }); |
61 | | - }); |
62 | | -} |
63 | | - |
64 | 14 | void max_elementwise_kernel_cuda(TensorIterator& iter) { |
65 | 15 | if (iter.dtype() == ScalarType::Bool) { |
66 | 16 | gpu_kernel(iter, []GPU_LAMBDA(bool a, bool b) -> bool { |
@@ -119,13 +69,6 @@ void min_elementwise_kernel_cuda(TensorIterator& iter) { |
119 | 69 | } |
120 | 70 | } |
121 | 71 |
|
122 | | - |
123 | | -REGISTER_DISPATCH(lt_stub, &lt_kernel_cuda); |
124 | | -REGISTER_DISPATCH(le_stub, &le_kernel_cuda); |
125 | | -REGISTER_DISPATCH(gt_stub, &gt_kernel_cuda); |
126 | | -REGISTER_DISPATCH(ge_stub, &ge_kernel_cuda); |
127 | | -REGISTER_DISPATCH(eq_stub, &eq_kernel_cuda); |
128 | | -REGISTER_DISPATCH(ne_stub, &ne_kernel_cuda); |
129 | 72 | REGISTER_DISPATCH(max_elementwise_stub, &max_elementwise_kernel_cuda); |
130 | 73 | REGISTER_DISPATCH(min_elementwise_stub, &min_elementwise_kernel_cuda); |
131 | 74 |
|
|
0 commit comments