[WIP] Flip a tensor (CPU + CUDA implementation) #6867
Closed
Commits (20)
acf9515  added CPU torch.flip function (weiyangfb)
881c88b  added back third_party/eigen, third_party/onnx (weiyangfb)
af549df  checking and erasing duplicated dims (weiyangfb)
40a906f  try to revert third_party/eigen (weiyangfb)
debe079  try to revert third_party/onnx (weiyangfb)
fbc4193  try to revert aten/src/ATen/native/native_functions.yaml (weiyangfb)
dfae695  rm no needed comments at test/test_cuda.py (weiyangfb)
c3abcc0  CUDA implementation of flip a tensor, verbose version with logs (weiyangfb)
b66d3b4  cleaned up logs, added stress test for cuda flip (weiyangfb)
3539dc0  try to untrack third_party/eigen and third_party/onnx (weiyangfb)
d8abb9c  added back third_party/eigen and third_party/onnx (weiyangfb)
86dff1b  try to revert changes in third_party/eigen and third_party/onnx (weiyangfb)
205bd32  addressed comments, except for 1) 'You should do this in 1 copy inste… (weiyangfb)
0a44235  revert test/test_autograd.py and test/test_torch.py to remove error c… (weiyangfb)
c7a9b89  nit fixes (weiyangfb)
9590216  more nit fixes (weiyangfb)
a5537c2  more nit fixes (weiyangfb)
ec87215  working to support non-contiguous case in flip (cuda) (weiyangfb)
e6d9ae9  nits, and [WIP] support non-contiguous case in cuda (weiyangfb)
f385f42  nits (weiyangfb)
Files changed

New file: CPU implementation of flip.

#include "ATen/ATen.h"
#include "ATen/ExpandUtils.h"
#include "ATen/NativeFunctions.h"
#include <algorithm>
#include <functional>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <vector>

namespace at {
namespace native {

// Reverse a single dimension via index_select with a descending index [size-1, ..., 0]
Tensor reverse_dim(const Tensor& t, int64_t dim) {
  Tensor index = at::arange(t.type().toScalarType(at::ScalarType::Long), t.size(dim) - 1, -1, -1);
  return t.index_select(dim, index);
}

Tensor flip_cpu(const Tensor& self, IntList dims) {

  int64_t total_dims = self.dim(), flip_dims_size = dims.size();

  // check that the list of flip dims is not empty
  if (flip_dims_size == 0) {
    std::stringstream ss;
    ss << "expected input tensor dims not empty, "
       << "but got tensor dims size=" << flip_dims_size;
    throw std::runtime_error(ss.str());
  }

  // check for duplicates in dims (sort a copy so std::unique also removes non-adjacent duplicates)
  auto flip_dims_v = std::vector<int64_t>(dims);
  std::sort(flip_dims_v.begin(), flip_dims_v.end());
  flip_dims_v.erase(std::unique(flip_dims_v.begin(), flip_dims_v.end()), flip_dims_v.end());
  if ((int64_t)flip_dims_v.size() < flip_dims_size) {
    std::stringstream ss;
    ss << "dims has duplicates, "
       << "original flip dims size=" << flip_dims_size << ", "
       << "but unique flip dims size=" << flip_dims_v.size();
    throw std::runtime_error(ss.str());
  }

  // check the length of dims
  if (flip_dims_size > total_dims) {
    std::stringstream ss;
    ss << "expected flip dims size <= tensor total dims, "
       << "but got flip dims size=" << flip_dims_size << " and "
       << "tensor total dim=" << total_dims;
    throw std::runtime_error(ss.str());
  }

  // check that every axis in dims is within range
  int64_t min_d = total_dims, max_d = 0;
  for (auto d : dims) {
    min_d = std::min(min_d, d);
    max_d = std::max(max_d, d);
  }

  if (min_d < 0) {
    std::stringstream ss;
    ss << "expected flip dims axis >= 0, "
       << "but got min flip dims=" << min_d;
    throw std::runtime_error(ss.str());
  }

  if (max_d >= total_dims) {
    std::stringstream ss;
    ss << "expected flip dims axis < tensor total dims, "
       << "but got max flip dims=" << max_d << " and "
       << "tensor total dim=" << total_dims;
    throw std::runtime_error(ss.str());
  }

  // reverse each requested dimension in turn
  Tensor out_t = self.clone();
  for (auto d : dims) {
    out_t = reverse_dim(out_t, d);
  }
  return out_t;
}

}} // namespace at::native
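As a usage illustration only (not part of this PR's diff): a minimal sketch of the intended semantics, assuming the op is exposed as torch::flip in the libtorch C++ API, as it is in later PyTorch releases. The tensor values and the torch/torch.h entry point are illustrative assumptions, not taken from the PR.

#include <torch/torch.h>
#include <iostream>

int main() {
  // 2x3 tensor: [[0, 1, 2], [3, 4, 5]]
  auto x = torch::arange(6).reshape({2, 3});

  // Flip along dim 0 -> [[3, 4, 5], [0, 1, 2]]
  auto y = torch::flip(x, {0});

  // Flip along dims 0 and 1 -> [[5, 4, 3], [2, 1, 0]]
  auto z = torch::flip(x, {0, 1});

  std::cout << y << "\n" << z << std::endl;
  return 0;
}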
New file: CUDA implementation of flip.

#include "ATen/NativeFunctions.h"
#include "ATen/ATen.h"
#include <algorithm>
#include <sstream>
#include <stdexcept>
#include <vector>

#include "ATen/cuda/CUDATensorMethods.cuh"
#include "ATen/cuda/CUDATypeConversion.cuh"

namespace at {
namespace native {

// Map the index of an element in a tensor from 1D to nD.
// Assumes a contiguous layout, where strides[i] is the product of the sizes of all trailing dims.
__device__ __forceinline__
void linear_index_to_indices(int64_t linear_index, int64_t* strides, int64_t total_dims, int64_t* indices) {
  int64_t res = linear_index;
  for (int64_t i = 0; i < total_dims; i++) {
    int64_t indices_i = linear_index * total_dims + i;
    indices[indices_i] = res / strides[i];
    res = res % strides[i];
  }
}

/*
  Map the index of an element in a tensor from nD to 1D, where 1D refers to the unfolded
  (flattened) version of the nD tensor.

  Example: given the 3D tensor
  [
    [ [1, 2], [3, 4] ],
    [ [5, 6], [7, 8] ],
    [ [9, 10], [11, 12] ],
  ]

  the element 3 has nD index (0, 1, 0), and the strides are (4, 2, 1). To map nD to 1D, use the
  formula sum(index[i] * stride[i]): here 0 * 4 + 1 * 2 + 0 * 1 = 2, so the 1D index is 2.
*/
__device__ __forceinline__
int64_t indices_to_linear_index(int64_t* indices, int64_t total_dims, int64_t* strides, int64_t src_linear_index) {
  int64_t dest_linear_index = 0;
  for (int64_t i = 0; i < total_dims; i++) {
    int64_t indices_i = src_linear_index * total_dims + i;
    dest_linear_index += indices[indices_i] * strides[i];
  }
  return dest_linear_index;
}

template <typename scalar_t>
__global__
void flip_cuda_kernel(scalar_t* in_t, scalar_t* out_t, int64_t N, int64_t* dims, int64_t* indices,
                      int64_t flip_dims_size, int64_t* strides, int64_t* shape, int64_t total_dims) {

  int64_t linear_index = blockIdx.x * blockDim.x + threadIdx.x;
  if (linear_index >= N) {
    return;
  }

  linear_index_to_indices(linear_index, strides, total_dims, indices);

  // Flip the nD index along each requested dimension
  for (int64_t i = 0; i < flip_dims_size; i++) {
    int64_t dim = dims[i];
    int64_t indices_dim = linear_index * total_dims + dim;
    indices[indices_dim] = shape[dim] - 1 - indices[indices_dim];
  }
  int64_t dest_linear_index = indices_to_linear_index(indices, total_dims, strides, linear_index);
  out_t[linear_index] = in_t[dest_linear_index];
}

// Flip a tensor along the given list of dims
Tensor flip_cuda(const Tensor& self, IntList dims) {

  // TODO: support non-contiguous tensors
  auto in_t = self.contiguous();

  int64_t flip_dims_size = dims.size(), total_dims = in_t.dim(), N = in_t.numel();

  // check that the list of flip dims is not empty
  if (flip_dims_size == 0) {
    std::stringstream ss;
    ss << "expected input tensor dims not empty, "
       << "but got tensor dims size=" << flip_dims_size;
    throw std::runtime_error(ss.str());
  }

  // check for duplicates in dims (sort a copy so std::unique also removes non-adjacent duplicates)
  auto flip_dims_v = std::vector<int64_t>(dims);
  std::sort(flip_dims_v.begin(), flip_dims_v.end());
  flip_dims_v.erase(std::unique(flip_dims_v.begin(), flip_dims_v.end()), flip_dims_v.end());
  if ((int64_t)flip_dims_v.size() < flip_dims_size) {
    std::stringstream ss;
    ss << "dims has duplicates, "
       << "original flip dims size=" << flip_dims_size << ", "
       << "but unique flip dims size=" << flip_dims_v.size();
    throw std::runtime_error(ss.str());
  }

  // check the length of dims
  if (flip_dims_size > total_dims) {
    std::stringstream ss;
    ss << "expected flip dims size <= tensor total dims, "
       << "but got flip dims size=" << flip_dims_size << " and "
       << "tensor total dim=" << total_dims;
    throw std::runtime_error(ss.str());
  }

  // check that every axis in dims is within range
  int64_t min_d = total_dims, max_d = 0;
  for (auto d : dims) {
    min_d = std::min(min_d, d);
    max_d = std::max(max_d, d);
  }

  if (min_d < 0) {
    std::stringstream ss;
    ss << "expected flip dims axis >= 0, "
       << "but got min flip dims=" << min_d;
    throw std::runtime_error(ss.str());
  }

  if (max_d >= total_dims) {
    std::stringstream ss;
    ss << "expected flip dims axis < tensor total dims, "
       << "but got max flip dims=" << max_d << " and "
       << "tensor total dim=" << total_dims;
    throw std::runtime_error(ss.str());
  }

  // stage dims, shape and strides as Long tensors so they can be moved to the device
  auto flip_dims = std::vector<int64_t>(dims);
  auto flip_dims_t = at::CPU(kLong).tensorFromBlob(flip_dims.data(), {static_cast<int64_t>(flip_dims.size())});

  auto shape = std::vector<int64_t>(in_t.sizes());
  auto shape_t = at::CPU(kLong).tensorFromBlob(shape.data(), {static_cast<int64_t>(shape.size())});

  auto strides = std::vector<int64_t>(in_t.strides());
  auto strides_t = at::CPU(kLong).tensorFromBlob(strides.data(), {static_cast<int64_t>(strides.size())});

  // scratch buffer holding the nD index of every element, plus the output tensor
  auto indices = at::zeros(CUDA(kLong), {N, total_dims});
  auto out_t = at::zeros_like(in_t);

  // one thread per element
  int64_t block_size = 512;
  dim3 dim_block(block_size);
  dim3 dim_grid((N + block_size - 1) / block_size);

  AT_DISPATCH_ALL_TYPES_AND_HALF(in_t.type(), "flip_cuda", [&] {
    using cuda_scalar_t = cuda::type<scalar_t>;
    flip_cuda_kernel<<<dim_grid, dim_block, 0, globalContext().getCurrentCUDAStream()>>>(
        in_t.data<cuda_scalar_t>(), out_t.data<cuda_scalar_t>(), N,
        flip_dims_t.toType(CUDA(kLong)).data<int64_t>(), indices.data<int64_t>(), flip_dims_size,
        strides_t.toType(CUDA(kLong)).data<int64_t>(), shape_t.toType(CUDA(kLong)).data<int64_t>(),
        total_dims);
  });

  return out_t;
}

}} // namespace at::native
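To make the kernel's index arithmetic concrete, here is a host-side sketch, in plain C++, of the same 1D-to-nD-to-1D mapping for the 3 x 2 x 2 example in the comment above, under the contiguity assumption the kernel relies on (strides = {4, 2, 1}). The chosen element, flip dims, and variable names are illustrative only.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Shape {3, 2, 2}; contiguous strides are {4, 2, 1}.
  std::vector<int64_t> shape     = {3, 2, 2};
  std::vector<int64_t> strides   = {4, 2, 1};
  std::vector<int64_t> flip_dims = {0};   // flip along dim 0
  int64_t total_dims = shape.size();

  int64_t linear_index = 2;   // the element holding "3" in the example tensor

  // 1D -> nD: repeated division by the stride of each dimension (as in linear_index_to_indices)
  std::vector<int64_t> idx(total_dims);
  int64_t res = linear_index;
  for (int64_t i = 0; i < total_dims; i++) {
    idx[i] = res / strides[i];
    res = res % strides[i];
  }
  // idx is now {0, 1, 0}

  // Flip the requested dimensions: i -> size - 1 - i (as in flip_cuda_kernel)
  for (int64_t d : flip_dims) {
    idx[d] = shape[d] - 1 - idx[d];
  }
  // idx is now {2, 1, 0}

  // nD -> 1D: sum(idx[i] * strides[i]) (as in indices_to_linear_index)
  int64_t src = 0;
  for (int64_t i = 0; i < total_dims; i++) {
    src += idx[i] * strides[i];
  }

  // Prints: output position 2 reads from input position 10
  std::cout << "output position " << linear_index
            << " reads from input position " << src << std::endl;
  return 0;
}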