44#include < ATen/NativeFunctions.h>
55#include < ATen/cuda/PinnedMemoryAllocator.h>
66#include < ATen/cuda/CUDAApplyUtils.cuh>
7+ #include < ATen/cuda/detail/IndexUtils.cuh>
78
89#include < ATen/native/LinearAlgebraUtils.h>
910#include < ATen/native/cuda/MiscUtils.h>
@@ -795,59 +796,75 @@ std::tuple<Tensor, Tensor, Tensor> _lu_with_info_cuda(const Tensor& self, bool p
795796
796797// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
797798
// Element-wise kernel shared by triu and tril.
//
// Each thread handles one element of `self`: the flat element index is
// decomposed (innermost dimension first) into a column, a row and the
// remaining leading indices using `self_info.sizes`, and the matching
// offsets into `self` and `result` are accumulated from each tensor's own
// strides. The element is copied to `result` when it satisfies the
// triangular mask, otherwise zero is written.
//
// Template parameters:
//   scalar_t  - element type of both tensors.
//   IndexType - integer type used for offsets (instantiated with int32_t
//               and int64_t by the host launcher in this file).
//   upper     - true keeps elements with col - row >= k,
//               false keeps elements with col - row <= k.
//
// Parameters:
//   result_info - sizes/strides/data of the destination tensor.
//   self_info   - sizes/strides/data of the source tensor.
//   k           - diagonal offset used in the mask.
//   N           - number of elements to process; threads whose flat index
//                 is >= N exit immediately.
//
// Launch layout: only the x components of blockIdx/blockDim/threadIdx are
// read, so a 1-D grid of 1-D blocks covering at least N threads is expected.
//
// NOTE(review): `result` is addressed with indices derived from
// `self_info.sizes` only, so this assumes the caller has already given
// `result` the same shape as `self` - confirm at the call sites.
template <typename scalar_t, typename IndexType, bool upper>
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(512)
#endif
__global__
void triu_tril_kernel(
    cuda::detail::TensorInfo<scalar_t, IndexType> result_info,
    const cuda::detail::TensorInfo<scalar_t, IndexType> self_info,
    const int64_t k, const int64_t N) {
  // Widen before multiplying: blockIdx.x and blockDim.x are 32-bit unsigned,
  // so their product would wrap for global indices >= 2^32 before being
  // stored into the 64-bit variable.
  int64_t linear_idx =
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (linear_idx >= N) {
    return;
  }

  auto dims = self_info.dims;

  IndexType self_offset = 0, result_offset = 0;
  // Compute column index and corresponding offset
  IndexType col = linear_idx % self_info.sizes[dims - 1];
  linear_idx /= self_info.sizes[dims - 1];
  self_offset += self_info.strides[dims - 1] * col;
  result_offset += result_info.strides[dims - 1] * col;

  // Compute row index and corresponding offset
  IndexType row = linear_idx % self_info.sizes[dims - 2];
  linear_idx /= self_info.sizes[dims - 2];
  self_offset += self_info.strides[dims - 2] * row;
  result_offset += result_info.strides[dims - 2] * row;

  // Compute remaining offsets: peel off the leading dimensions from the
  // innermost one outwards. IndexType is signed in both instantiations, so
  // the `i >= 0` condition terminates.
  IndexType running_index;
  #pragma unroll
  for (IndexType i = dims - 3; i >= 0; --i) {
    running_index = linear_idx % self_info.sizes[i];
    linear_idx /= self_info.sizes[i];
    self_offset += running_index * self_info.strides[i];
    result_offset += running_index * result_info.strides[i];
  }

  // Keep the element when it lies on the requested side of the k-th diagonal.
  bool mask = upper ? (col - row >= k) : (col - row <= k);
  result_info.data[result_offset] = mask ? self_info.data[self_offset] : scalar_t(0);
}
823841
// Host-side launcher shared by triu (upper == true) and tril (upper == false).
//
// Launches triu_tril_kernel with one thread per element of `self` on the
// current CUDA stream. The kernel is instantiated with 32-bit offsets when
// both `result` and `self` report that 32-bit index math is usable, and with
// 64-bit offsets otherwise. `name` is the label passed to the dtype dispatch
// macro. Any launch error is surfaced through AT_CUDA_CHECK before `result`
// is returned.
template <bool upper>
Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, const char* name) {
  const int64_t total_elems = self.numel();
  const dim3 block = cuda::getApplyBlock();
  // Ceil-divide so the tail elements get a thread too.
  const dim3 grid((total_elems + block.x - 1) / block.x);

  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), name, [&]{
    const bool fits_32bit =
        cuda::detail::canUse32BitIndexMath(result) &&
        cuda::detail::canUse32BitIndexMath(self);
    if (!fits_32bit) {
      // At least one tensor needs 64-bit offsets.
      auto dst = cuda::detail::getTensorInfo<scalar_t, int64_t>(result);
      auto src = cuda::detail::getTensorInfo<scalar_t, int64_t>(self);
      triu_tril_kernel<scalar_t, int64_t, upper>
        <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
          dst, src, k, total_elems);
    } else {
      auto dst = cuda::detail::getTensorInfo<scalar_t, int32_t>(result);
      auto src = cuda::detail::getTensorInfo<scalar_t, int32_t>(self);
      triu_tril_kernel<scalar_t, int32_t, upper>
        <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
          dst, src, k, total_elems);
    }
  });
  AT_CUDA_CHECK(cudaGetLastError());
  return result;
}
843865
// In-place tril: forwards to tril_cuda_out with `self` as both destination
// and source, and returns the reference that call yields.
Tensor& tril_cuda_(Tensor &self, int64_t k) {
  Tensor& out = tril_cuda_out(self, self, k);
  return out;
}
852869
853870Tensor& tril_cuda_out (Tensor &result, const Tensor& self, int64_t k) {
@@ -857,17 +874,11 @@ Tensor& tril_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
857874 if (self.numel () == 0 ) {
858875 return result;
859876 }
860- Tensor self_c = checkTrilTriuBatchContiguous (self) ? self : self.contiguous ();
861- return triu_tril_cuda_template<false >(result, self_c, k, " tril" );
877+ return triu_tril_cuda_template<false >(result, self, k, " tril" );
862878}
863879
// In-place triu: forwards to triu_cuda_out with `self` as both destination
// and source, and returns the reference that call yields.
Tensor& triu_cuda_(Tensor &self, int64_t k) {
  Tensor& out = triu_cuda_out(self, self, k);
  return out;
}
872883
873884Tensor& triu_cuda_out (Tensor &result, const Tensor& self, int64_t k) {
@@ -877,8 +888,7 @@ Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
877888 if (self.numel () == 0 ) {
878889 return result;
879890 }
880- Tensor self_c = checkTrilTriuBatchContiguous (self) ? self : self.contiguous ();
881- return triu_tril_cuda_template<true >(result, self_c, k, " triu" );
891+ return triu_tril_cuda_template<true >(result, self, k, " triu" );
882892}
883893
884894// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
0 commit comments