88 changes: 49 additions & 39 deletions aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
@@ -4,6 +4,7 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/cuda/PinnedMemoryAllocator.h>
 #include <ATen/cuda/CUDAApplyUtils.cuh>
+#include <ATen/cuda/detail/IndexUtils.cuh>

 #include <ATen/native/LinearAlgebraUtils.h>
 #include <ATen/native/cuda/MiscUtils.h>
@@ -795,59 +796,75 @@ std::tuple<Tensor, Tensor, Tensor> _lu_with_info_cuda(const Tensor& self, bool p

 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

-template <typename scalar_t, bool upper>
+template <typename scalar_t, typename IndexType, bool upper>
 #ifdef __HIP_PLATFORM_HCC__
 C10_LAUNCH_BOUNDS_1(512)
 #endif
 __global__
 void triu_tril_kernel(
-    scalar_t* result, scalar_t* self, int64_t k, int64_t N,
-    int64_t res_batch_stride, int64_t res_row_stride, int64_t res_col_stride,
-    int64_t self_batch_stride, int64_t self_row_stride, int64_t self_col_stride, int64_t self_ncol) {
+    cuda::detail::TensorInfo<scalar_t, IndexType> result_info,
+    const cuda::detail::TensorInfo<scalar_t, IndexType> self_info,
+    const int64_t k, const int64_t N) {
   int64_t linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (linear_idx >= N) {
     return;
   }

-  int64_t self_batch_idx = blockIdx.y;
-  int64_t row = linear_idx / self_ncol;
-  int64_t col = linear_idx % self_ncol;
+  auto dims = self_info.dims;

+  IndexType self_offset = 0, result_offset = 0;
+  // Compute column index and corresponding offset
+  IndexType col = linear_idx % self_info.sizes[dims - 1];
+  linear_idx /= self_info.sizes[dims - 1];
+  self_offset += self_info.strides[dims - 1] * col;
+  result_offset += result_info.strides[dims - 1] * col;
+
+  // Compute row index and corresponding offset
+  IndexType row = linear_idx % self_info.sizes[dims - 2];
+  linear_idx /= self_info.sizes[dims - 2];
+  self_offset += self_info.strides[dims - 2] * row;
+  result_offset += result_info.strides[dims - 2] * row;
+
+  // Compute remaining offsets
+  IndexType running_index;
+  #pragma unroll
+  for (IndexType i = dims - 3; i >= 0; --i) {
+    running_index = linear_idx % self_info.sizes[i];
+    linear_idx /= self_info.sizes[i];
+    self_offset += running_index * self_info.strides[i];
+    result_offset += running_index * result_info.strides[i];
+  }
+
   bool mask = upper ? (col - row >= k) : (col - row <= k);
-
-  // Now compute the offset for the self and result tensor
-  int64_t res_offset = self_batch_idx * res_batch_stride + row * res_row_stride + col * res_col_stride;
-  int64_t self_offset = self_batch_idx * self_batch_stride + row * self_row_stride + col * self_col_stride;
-  result[res_offset] = mask ? self[self_offset] : scalar_t(0);
+  result_info.data[result_offset] = mask ? self_info.data[self_offset] : scalar_t(0);
 }

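Note on the kernel change above: rather than assuming a batch-contiguous layout, each thread now peels its flat element index into per-dimension coordinates (innermost dimension first) and accumulates each tensor's own strides into separate read and write offsets, which is what lets the kernel accept arbitrarily strided inputs. A minimal host-side C++ sketch of the same decomposition, with a hypothetical StridedInfo struct standing in for cuda::detail::TensorInfo (the kernel additionally hand-unrolls the last two dimensions to extract row and col for the mask):

    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in for cuda::detail::TensorInfo: per-dimension
    // sizes and strides of one tensor.
    struct StridedInfo {
      std::vector<int64_t> sizes;
      std::vector<int64_t> strides;
    };

    // Translate a flat element index into a storage offset by peeling off
    // one coordinate per dimension, innermost dimension first.
    int64_t offset_of(int64_t linear_idx, const StridedInfo& info) {
      int64_t offset = 0;
      for (int64_t i = static_cast<int64_t>(info.sizes.size()) - 1; i >= 0; --i) {
        const int64_t coord = linear_idx % info.sizes[i];
        linear_idx /= info.sizes[i];
        offset += coord * info.strides[i];
      }
      return offset;
    }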
 template <bool upper>
 Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, const char* name) {
-  int64_t n_batches = batchCount(self), mat_size = self.size(-1) * self.size(-2),
-          res_batch_stride = result.dim() > 2 ? result.stride(-3) : 1,
-          res_row_stride = result.stride(-2), res_col_stride = result.stride(-1),
-          self_batch_stride = self.dim() > 2 ? self.stride(-3) : 1,
-          self_row_stride = self.stride(-2), self_col_stride = self.stride(-1);
+  int64_t N = self.numel();
   dim3 dim_block = cuda::getApplyBlock();
-  dim3 dim_grid((mat_size + dim_block.x - 1) / dim_block.x, n_batches);
+  dim3 dim_grid((N + dim_block.x - 1) / dim_block.x);
   AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, self.scalar_type(), name, [&]{
-    triu_tril_kernel<scalar_t, upper>
-      <<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
-        result.data<scalar_t>(), self.data<scalar_t>(), k, mat_size,
-        res_batch_stride, res_row_stride, res_col_stride,
-        self_batch_stride, self_row_stride, self_col_stride, self.size(-1));
+    if (cuda::detail::canUse32BitIndexMath(result) && cuda::detail::canUse32BitIndexMath(self)) {
+      auto result_info = cuda::detail::getTensorInfo<scalar_t, int32_t>(result);
+      auto self_info = cuda::detail::getTensorInfo<scalar_t, int32_t>(self);
+      triu_tril_kernel<scalar_t, int32_t, upper>
+        <<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+          result_info, self_info, k, N);
+    } else {
+      auto result_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(result);
+      auto self_info = cuda::detail::getTensorInfo<scalar_t, int64_t>(self);
+      triu_tril_kernel<scalar_t, int64_t, upper>
+        <<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+          result_info, self_info, k, N);
+    }
   });
   AT_CUDA_CHECK(cudaGetLastError());
   return result;
 }

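The template now flattens the whole tensor into a single 1-D launch over self.numel() (the old 2-D grid with the batch in blockIdx.y is gone) and instantiates the kernel twice: with int32_t offsets when canUse32BitIndexMath confirms every element of both tensors is addressable in 32 bits, and with int64_t otherwise, since 32-bit index arithmetic is cheaper on the GPU. A self-contained CUDA sketch of the same two-way dispatch, using a toy fill_iota kernel and a plain element-count test standing in for ATen's stride-aware canUse32BitIndexMath:

    #include <cstdint>

    // Toy kernel templated on the index type, mirroring the IndexType
    // template parameter that triu_tril_kernel gains in this change.
    template <typename IndexType>
    __global__ void fill_iota(float* out, IndexType n) {
      IndexType i = static_cast<IndexType>(blockIdx.x) * blockDim.x + threadIdx.x;
      if (i < n) {
        out[i] = static_cast<float>(i);
      }
    }

    // Launch the 32-bit instantiation when the element count fits, and
    // fall back to 64-bit indexing otherwise.
    void launch_fill(float* out, int64_t n) {
      const int64_t threads = 256;
      const dim3 block(threads);
      // Ceil-division, as in the dim_grid computation above.
      const dim3 grid(static_cast<unsigned int>((n + threads - 1) / threads));
      if (n <= INT32_MAX) {
        fill_iota<int32_t><<<grid, block>>>(out, static_cast<int32_t>(n));
      } else {
        fill_iota<int64_t><<<grid, block>>>(out, n);
      }
    }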
 Tensor& tril_cuda_(Tensor &self, int64_t k) {
-  bool inplace = checkTrilTriuBatchContiguous(self);
-  Tensor self_c = inplace ? self : self.contiguous();
-  Tensor result = inplace ? self : at::empty_like(self);
-  tril_cuda_out(result, self_c, k);
-  if (!inplace) self.copy_(result);
-  return self;
+  return tril_cuda_out(self, self, k);
 }

 Tensor& tril_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
@@ -857,17 +874,11 @@ Tensor& tril_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
   if (self.numel() == 0) {
     return result;
   }
-  Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
-  return triu_tril_cuda_template<false>(result, self_c, k, "tril");
+  return triu_tril_cuda_template<false>(result, self, k, "tril");
 }

 Tensor& triu_cuda_(Tensor &self, int64_t k) {
-  bool inplace = checkTrilTriuBatchContiguous(self);
-  Tensor self_c = inplace ? self : self.contiguous();
-  Tensor result = inplace ? self : at::empty_like(self);
-  triu_cuda_out(result, self_c, k);
-  if (!inplace) self.copy_(result);
-  return self;
+  return triu_cuda_out(self, self, k);
 }

 Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
@@ -877,8 +888,7 @@ Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
   if (self.numel() == 0) {
     return result;
   }
-  Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
-  return triu_tril_cuda_template<true>(result, self_c, k, "triu");
+  return triu_tril_cuda_template<true>(result, self, k, "triu");
 }
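Because the kernel is now stride-aware, the four wrappers above no longer route through checkTrilTriuBatchContiguous or a contiguous temporary: the out-variants pass self straight through, and the in-place variants simply alias result and self. An illustrative ATen C++ call (the tensor shapes and values here are arbitrary, not from the PR):

    #include <ATen/ATen.h>

    // A non-contiguous (transposed) view is now handled directly by the
    // stride-aware kernel; no intermediate contiguous copy is made.
    void example() {
      at::Tensor a = at::rand({4, 4}, at::kCUDA);
      at::Tensor lower = at::tril(a.t(), /*k=*/0);  // reads through the view's strides
      a.t().triu_(1);                               // in-place on the same storage
    }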

 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~