Skip to content

Commit cf5c3bb

Browse files
Natalia Gimelsheinfacebook-github-bot
authored andcommitted
make range functions respect current stream (#21619)
Summary: Stream is not respected on range/linspace/logspace functions, which contributes to #21589 (this is not a complete solution for that issue). Pull Request resolved: #21619 Differential Revision: D15769666 Pulled By: ezyang fbshipit-source-id: 7c036f7aecb3119430c4d432775cad98a5028fa8
1 parent 9241c4b commit cf5c3bb

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

aten/src/ATen/native/cuda/RangeFactories.cu

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
#include <ATen/NativeFunctions.h>
33
#include <ATen/AccumulateType.h>
44
#include <ATen/cuda/Exceptions.h>
5+
#include <ATen/cuda/CUDAContext.h>
56
#include <cmath>
67
#include <limits>
78

89
#include <thrust/device_ptr.h>
910
#include <thrust/sequence.h>
11+
#include <thrust/execution_policy.h>
1012

1113
namespace at {
1214
namespace native {
@@ -56,7 +58,9 @@ Tensor& linspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t step
5658
scalar_t step = (scalar_end - scalar_start) / static_cast<scalar_t>(steps - 1);
5759
LinspaceOp<scalar_t> linspace_method(scalar_start, step);
5860
thrust::device_ptr<scalar_t> data_(r.data<scalar_t>());
59-
thrust::tabulate(data_, data_ + steps, linspace_method);
61+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
62+
auto policy = thrust::cuda::par.on(stream);
63+
thrust::tabulate(policy, data_, data_ + steps, linspace_method);
6064
});
6165
}
6266

@@ -87,7 +91,9 @@ Tensor& logspace_cuda_out(Tensor& result, Scalar start, Scalar end, int64_t step
8791
scalar_t step = (scalar_end - scalar_start) / static_cast<scalar_t>(steps - 1);
8892
LogspaceOp<scalar_t> logspace_method(scalar_start, step, scalar_base);
8993
thrust::device_ptr<scalar_t> data_(r.data<scalar_t>());
90-
thrust::tabulate(data_, data_ + steps, logspace_method);
94+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
95+
auto policy = thrust::cuda::par.on(stream);
96+
thrust::tabulate(policy, data_, data_ + steps, logspace_method);
9197
});
9298
}
9399

@@ -117,8 +123,10 @@ Tensor& range_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
117123
}
118124
Tensor r = result.is_contiguous() ? result : result.contiguous();
119125
LinspaceOp<scalar_t, accscalar_t> linspace_method(xstart, xstep);
126+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
127+
auto policy = thrust::cuda::par.on(stream);
120128
thrust::device_ptr<scalar_t> data_ptr(r.data<scalar_t>());
121-
thrust::tabulate(data_ptr, data_ptr + size, linspace_method);
129+
thrust::tabulate(policy, data_ptr, data_ptr + size, linspace_method);
122130

123131
if (!result.is_contiguous()) {
124132
result.copy_(r);
@@ -168,8 +176,10 @@ Tensor& arange_cuda_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
168176
}
169177
Tensor r = result.is_contiguous() ? result : result.contiguous();
170178
LinspaceOp<scalar_t, accscalar_t> linspace_method(xstart, xstep);
179+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
180+
auto policy = thrust::cuda::par.on(stream);
171181
thrust::device_ptr<scalar_t> data_ptr(r.data<scalar_t>());
172-
thrust::tabulate(data_ptr, data_ptr + size, linspace_method);
182+
thrust::tabulate(policy, data_ptr, data_ptr + size, linspace_method);
173183

174184
if (!result.is_contiguous()) {
175185
result.copy_(r);

0 commit comments

Comments
 (0)