Commit eb4d43d

vishwakftw authored and facebook-github-bot committed
Make CUDA triu / tril support batches of size > 65535 (#21067)
Summary: In the previous implementation of triu / tril, we passed the batch size in the 2nd dimension of the grid. That dimension is limited to 65535, which means that performing triu / tril on a tensor with batch size > 65535 will throw an error. This PR removes the dependence on the 2nd grid dimension, and the corresponding contiguity constraints.

Changelog:
- Compute offset, row and col in the kernel
- Use the 1st dimension of the grid alone
- Remove contiguity checks on tensors that become unnecessary as a result of this change

Pull Request resolved: #21067
Differential Revision: D15572501
Pulled By: ezyang
fbshipit-source-id: 93851cb661918ce794d43eeb12c8a38762e1358c
1 parent 057ddab commit eb4d43d
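
The shape of the fix is easier to see outside the diff. Below is a minimal, illustrative CUDA sketch (not code from this commit; tril_flat and launch_tril_flat are hypothetical names) for a contiguous float tensor of shape (batch, rows, cols): every element is covered by a single flat 1-D grid, so no grid dimension ever has to hold the batch count.

#include <cuda_runtime.h>
#include <cstdint>

__global__ void tril_flat(float* out, const float* in, int64_t k,
                          int64_t n, int64_t rows, int64_t cols) {
  // One thread per element of the whole batched tensor.
  int64_t idx = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
  if (idx >= n) return;
  // Recover (row, col) from the flat index; the batch index is not needed
  // here because idx already addresses the right element of a contiguous tensor.
  int64_t col = idx % cols;
  int64_t row = (idx / cols) % rows;
  bool keep = (col - row) <= k;      // tril: keep elements on/below the k-th diagonal
  out[idx] = keep ? in[idx] : 0.0f;
}

void launch_tril_flat(float* out, const float* in, int64_t k,
                      int64_t batch, int64_t rows, int64_t cols) {
  int64_t n = batch * rows * cols;
  dim3 block(256);
  // 1-D grid: gridDim.x can hold up to 2^31 - 1 blocks on current GPUs.
  // The previous scheme launched dim3 grid(blocks_per_matrix, batch), and
  // gridDim.y is capped at 65535, so batch > 65535 failed to launch.
  dim3 grid((n + block.x - 1) / block.x);
  tril_flat<<<grid, block>>>(out, in, k, n, rows, cols);
}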

File tree

1 file changed: +49 −39 lines changed


aten/src/ATen/native/cuda/BatchLinearAlgebra.cu

Lines changed: 49 additions & 39 deletions
@@ -4,6 +4,7 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/cuda/PinnedMemoryAllocator.h>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/cuda/detail/IndexUtils.cuh>
 
 #include <ATen/native/LinearAlgebraUtils.h>
 #include <ATen/native/cuda/MiscUtils.h>
@@ -795,59 +796,75 @@ std::tuple<Tensor, Tensor, Tensor> _lu_with_info_cuda(const Tensor& self, bool p
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-template <typename scalar_t, bool upper>
+template <typename scalar_t, typename IndexType, bool upper>
 #ifdef __HIP_PLATFORM_HCC__
 C10_LAUNCH_BOUNDS_1(512)
 #endif
 __global__
 void triu_tril_kernel(
-    scalar_t* result, scalar_t* self, int64_t k, int64_t N,
-    int64_t res_batch_stride, int64_t res_row_stride, int64_t res_col_stride,
-    int64_t self_batch_stride, int64_t self_row_stride, int64_t self_col_stride, int64_t self_ncol) {
+    cuda::detail::TensorInfo<scalar_t, IndexType> result_info,
+    const cuda::detail::TensorInfo<scalar_t, IndexType> self_info,
+    const int64_t k, const int64_t N) {
   int64_t linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (linear_idx >= N) {
     return;
   }
 
-  int64_t self_batch_idx = blockIdx.y;
-  int64_t row = linear_idx / self_ncol;
-  int64_t col = linear_idx % self_ncol;
+  auto dims = self_info.dims;
+
+  IndexType self_offset = 0, result_offset = 0;
+  // Compute column index and corresponding offset
+  IndexType col = linear_idx % self_info.sizes[dims - 1];
+  linear_idx /= self_info.sizes[dims - 1];
+  self_offset += self_info.strides[dims - 1] * col;
+  result_offset += result_info.strides[dims - 1] * col;
+
+  // Compute row index and corresponding offset
+  IndexType row = linear_idx % self_info.sizes[dims - 2];
+  linear_idx /= self_info.sizes[dims - 2];
+  self_offset += self_info.strides[dims - 2] * row;
+  result_offset += result_info.strides[dims - 2] * row;
+
+  // Compute remaining offsets
+  IndexType running_index;
+  #pragma unroll
+  for (IndexType i = dims - 3; i >= 0; --i) {
+    running_index = linear_idx % self_info.sizes[i];
+    linear_idx /= self_info.sizes[i];
+    self_offset += running_index * self_info.strides[i];
+    result_offset += running_index * result_info.strides[i];
+  }
 
   bool mask = upper ? (col - row >= k) : (col - row <= k);
-
-  // Now compute the offset for the self and result tensor
-  int64_t res_offset = self_batch_idx * res_batch_stride + row * res_row_stride + col * res_col_stride;
-  int64_t self_offset = self_batch_idx * self_batch_stride + row * self_row_stride + col * self_col_stride;
-  result[res_offset] = mask ? self[self_offset] : scalar_t(0);
+  result_info.data[result_offset] = mask ? self_info.data[self_offset] : scalar_t(0);
 }
 
 template <bool upper>
 Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, const char* name) {
-  int64_t n_batches = batchCount(self), mat_size = self.size(-1) * self.size(-2),
-          res_batch_stride = result.dim() > 2 ? result.stride(-3) : 1,
-          res_row_stride = result.stride(-2), res_col_stride = result.stride(-1),
-          self_batch_stride = self.dim() > 2 ? self.stride(-3) : 1,
-          self_row_stride = self.stride(-2), self_col_stride = self.stride(-1);
+  int64_t N = self.numel();
   dim3 dim_block = cuda::getApplyBlock();
-  dim3 dim_grid((mat_size + dim_block.x - 1) / dim_block.x, n_batches);
+  dim3 dim_grid((N + dim_block.x - 1) / dim_block.x);
   AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), name, [&]{
-    triu_tril_kernel<scalar_t, upper>
-      <<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
-        result.data<scalar_t>(), self.data<scalar_t>(), k, mat_size,
-        res_batch_stride, res_row_stride, res_col_stride,
-        self_batch_stride, self_row_stride, self_col_stride, self.size(-1));
+    if (cuda::detail::canUse32BitIndexMath(result) && cuda::detail::canUse32BitIndexMath(self)) {
+      auto result_info = cuda::detail::getTensorInfo<scalar_t, int32_t>(result);
+      auto self_info = cuda::detail::getTensorInfo<scalar_t, int32_t>(self);
+      triu_tril_kernel<scalar_t, int32_t, upper>
+        <<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+          result_info, self_info, k, N);
+    } else {
+      auto result_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(result);
+      auto self_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(self);
+      triu_tril_kernel<scalar_t, int64_t, upper>
+        <<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+          result_info, self_info, k, N);
+    }
   });
   AT_CUDA_CHECK(cudaGetLastError());
   return result;
 }
 
 Tensor& tril_cuda_(Tensor &self, int64_t k) {
-  bool inplace = checkTrilTriuBatchContiguous(self);
-  Tensor self_c = inplace ? self : self.contiguous();
-  Tensor result = inplace ? self : at::empty_like(self);
-  tril_cuda_out(result, self_c, k);
-  if (!inplace) self.copy_(result);
-  return self;
+  return tril_cuda_out(self, self, k);
 }
 
 Tensor& tril_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
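
The loop added to the kernel above is the standard "unravel a linear index" pattern: peel off the innermost dimension with a modulus, advance the offset by that dimension's stride, and repeat outwards. This is what lets the kernel address arbitrarily strided (non-contiguous) inputs without any batch-contiguity check. A small host-side sketch of the same arithmetic (illustrative only; offset_from_linear is a hypothetical helper, not part of ATen):

#include <cstdint>
#include <cstdio>

// Map a flat element number to a storage offset, given sizes and strides.
// This mirrors the per-thread arithmetic in triu_tril_kernel.
int64_t offset_from_linear(int64_t linear_idx,
                           const int64_t* sizes, const int64_t* strides,
                           int dims) {
  int64_t offset = 0;
  for (int i = dims - 1; i >= 0; --i) {   // innermost (column) dimension first
    int64_t index = linear_idx % sizes[i];
    linear_idx /= sizes[i];
    offset += index * strides[i];
  }
  return offset;
}

int main() {
  // A (2, 3, 4) view whose matrices are transposed, hence not contiguous:
  // strides are (12, 1, 3) instead of the contiguous (12, 4, 1).
  int64_t sizes[3]   = {2, 3, 4};
  int64_t strides[3] = {12, 1, 3};
  // Flat element 10 is (batch 0, row 2, col 2); its storage offset is 2*1 + 2*3 = 8.
  std::printf("%lld\n", (long long)offset_from_linear(10, sizes, strides, 3));  // prints 8
  return 0;
}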
@@ -857,17 +874,11 @@ Tensor& tril_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
   if (self.numel() == 0) {
     return result;
   }
-  Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
-  return triu_tril_cuda_template<false>(result, self_c, k, "tril");
+  return triu_tril_cuda_template<false>(result, self, k, "tril");
 }
 
 Tensor& triu_cuda_(Tensor &self, int64_t k) {
-  bool inplace = checkTrilTriuBatchContiguous(self);
-  Tensor self_c = inplace ? self : self.contiguous();
-  Tensor result = inplace ? self : at::empty_like(self);
-  triu_cuda_out(result, self_c, k);
-  if (!inplace) self.copy_(result);
-  return self;
+  return triu_cuda_out(self, self, k);
 }
 
 Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
@@ -877,8 +888,7 @@ Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
   if (self.numel() == 0) {
     return result;
   }
-  Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
-  return triu_tril_cuda_template<true>(result, self_c, k, "triu");
+  return triu_tril_cuda_template<true>(result, self, k, "triu");
 }
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
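
With the second grid dimension gone, a batch dimension above 65535 dispatches like any other size. A minimal smoke-test sketch against the public ATen C++ API (shapes chosen arbitrarily; assumes a CUDA build; not a test from this PR):

#include <ATen/ATen.h>

int main() {
  // 70000 batched 8x8 matrices; the old launch configuration rejected batch > 65535.
  auto opts = at::TensorOptions().device(at::kCUDA).dtype(at::kFloat);
  auto t = at::randn({70000, 8, 8}, opts);
  auto lower = t.tril();          // batched lower-triangular part
  auto upper = t.triu(/*k=*/1);   // batched strictly-upper-triangular part
  return 0;
}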
