|
2 | 2 | #include <ATen/NativeFunctions.h> |
3 | 3 | #include <ATen/AccumulateType.h> |
4 | 4 | #include <ATen/cuda/Exceptions.h> |
| 5 | +#include <ATen/cuda/CUDAContext.h> |
5 | 6 | #include <cmath> |
6 | 7 | #include <limits> |
7 | 8 |
|
8 | 9 | #include <thrust/device_ptr.h> |
9 | 10 | #include <thrust/sequence.h> |
| 11 | +#include <thrust/execution_policy.h> |
10 | 12 |
|
11 | 13 | namespace at { |
12 | 14 | namespace native { |
@@ -56,7 +58,9 @@ Tensor& linspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t step |
56 | 58 | scalar_t step = (scalar_end - scalar_start) / static_cast<scalar_t>(steps - 1); |
57 | 59 | LinspaceOp<scalar_t> linspace_method(scalar_start, step); |
58 | 60 | thrust::device_ptr<scalar_t> data_(r.data<scalar_t>()); |
59 | | - thrust::tabulate(data_, data_ + steps, linspace_method); |
| 61 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| 62 | + auto policy = thrust::cuda::par.on(stream); |
| 63 | + thrust::tabulate(policy, data_, data_ + steps, linspace_method); |
60 | 64 | }); |
61 | 65 | } |
62 | 66 |
|
@@ -87,7 +91,9 @@ Tensor& logspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t step |
87 | 91 | scalar_t step = (scalar_end - scalar_start) / static_cast<scalar_t>(steps - 1); |
88 | 92 | LogspaceOp<scalar_t> logspace_method(scalar_start, step, scalar_base); |
89 | 93 | thrust::device_ptr<scalar_t> data_(r.data<scalar_t>()); |
90 | | - thrust::tabulate(data_, data_ + steps, logspace_method); |
| 94 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| 95 | + auto policy = thrust::cuda::par.on(stream); |
| 96 | + thrust::tabulate(policy, data_, data_ + steps, logspace_method); |
91 | 97 | }); |
92 | 98 | } |
93 | 99 |
|
@@ -117,8 +123,10 @@ Tensor& range_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) { |
117 | 123 | } |
118 | 124 | Tensor r = result.is_contiguous() ? result : result.contiguous(); |
119 | 125 | LinspaceOp<scalar_t, accscalar_t> linspace_method(xstart, xstep); |
| 126 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| 127 | + auto policy = thrust::cuda::par.on(stream); |
120 | 128 | thrust::device_ptr<scalar_t> data_ptr(r.data<scalar_t>()); |
121 | | - thrust::tabulate(data_ptr, data_ptr + size, linspace_method); |
| 129 | + thrust::tabulate(policy, data_ptr, data_ptr + size, linspace_method); |
122 | 130 |
|
123 | 131 | if (!result.is_contiguous()) { |
124 | 132 | result.copy_(r); |
@@ -168,8 +176,10 @@ Tensor& arange_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) { |
168 | 176 | } |
169 | 177 | Tensor r = result.is_contiguous() ? result : result.contiguous(); |
170 | 178 | LinspaceOp<scalar_t, accscalar_t> linspace_method(xstart, xstep); |
| 179 | + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| 180 | + auto policy = thrust::cuda::par.on(stream); |
171 | 181 | thrust::device_ptr<scalar_t> data_ptr(r.data<scalar_t>()); |
172 | | - thrust::tabulate(data_ptr, data_ptr + size, linspace_method); |
| 182 | + thrust::tabulate(policy, data_ptr, data_ptr + size, linspace_method); |
173 | 183 |
|
174 | 184 | if (!result.is_contiguous()) { |
175 | 185 | result.copy_(r); |
|
0 commit comments