173 changes: 168 additions & 5 deletions aten/src/ATen/native/cuda/Shape.cu
@@ -12,7 +12,11 @@
namespace at {
namespace native {

#ifdef __HIP_PLATFORM_HCC__
constexpr int CAT_ARRAY_BATCH_SIZE = 1024;
#else
constexpr int CAT_ARRAY_BATCH_SIZE = 128;
#endif
constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 4;

namespace {
@@ -78,6 +82,46 @@ struct OutputTensorSizeStride {
* The most important assumption made is that the input tensors are contiguous.
*/


// Use pinned memory and pass the struct by pointer on ROCm
template <typename T, typename IndexType>
struct CatArrInputTensor {
T* input;
IndexType offset;
IndexType dimSize;
IndexType nElements;
};

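// Each block along grid y copies one input tensor of the current batch; the
// threads of that block grid-stride over the tensor's elements and scatter
// them into the output at offsets computed for the concat dimension.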
template <typename T, typename IndexType, int Dims>
C10_LAUNCH_BOUNDS_1(512)
__global__ void HIP_CatArrayBatchedCopy(
T* output,
CatArrInputTensor<T, IndexType>* inputs,
OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
const int concatDim,
IndexType dimStride) {

IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
IndexType nElements = inputs[blockIdx.y].nElements;

if(tid >= nElements) return;

T* data = inputs[blockIdx.y].input;
IndexType offset = inputs[blockIdx.y].offset;
IndexType dimSize = inputs[blockIdx.y].dimSize;
IndexType dataOffset = offset * dimStride;

IndexType stride = gridDim.x * blockDim.x;

while( tid < nElements){
IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
os.outputSize, os.outputStride, dimSize, concatDim, tid);
output[dataOffset + elementOffset] = data[tid];

tid += stride;
}
}

// Pass metadata directly through kernel arguments instead of pinned memory
template <typename T, typename IndexType, int n>
struct CatArrInputTensorMetadata {
@@ -88,9 +132,6 @@ struct CatArrInputTensorMetadata {
};

template <typename T, typename IndexType, int Dims>
#ifdef __HIP_PLATFORM_HCC__
C10_LAUNCH_BOUNDS_1(512)
#endif
__global__ void CatArrayBatchedCopy(
T* output,
CatArrInputTensorMetadata<T, IndexType, CAT_ARRAY_BATCH_SIZE> inputs,
@@ -141,6 +182,122 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
}
}

template <typename scalar_t>
void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
int nDims, c10::MemoryFormat memory_format) {
// First, let's set up our kernel parameters. We start with a raw pointer to
// the storage for the output Tensor.
scalar_t *data = out.data_ptr<scalar_t>();

// Kernel parameter: device-side buffer holding the per-input metadata for one batch
long tensorMetadataSize =
sizeof(CatArrInputTensor<scalar_t, unsigned int>) * CAT_ARRAY_BATCH_SIZE;
auto d_inputs_storage = at::empty(
{tensorMetadataSize}, out.options().dtype(at::kByte));
auto d_inputs = static_cast<CatArrInputTensor<scalar_t, unsigned int> *>(
d_inputs_storage.data_ptr());
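// d_inputs is a device-side scratch buffer; each batch's metadata is staged
// in pinned host memory below and copied into it asynchronously.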

OutputTensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> param;

// Next, let's initialize the size and stride arrays for the output Tensor.
if (memory_format == c10::MemoryFormat::Contiguous) {
for (int i = 0; i < nDims; ++i) {
param.outputSize[i] = at::native::size(out, i);
param.outputStride[i] = out.stride(i);
}
} else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) {
// permute the semantics of dims from NCHW to NHWC so that the input
// tensor is now contiguous
param.outputSize[0] = at::native::size(out, 0);
param.outputStride[0] = out.stride(0);
for (int i = 1; i < nDims - 1; ++i) {
param.outputSize[i] = at::native::size(out, i + 1);
param.outputStride[i] = out.stride(i + 1);
}
param.outputSize[nDims - 1] = at::native::size(out, 1);
param.outputStride[nDims - 1] = out.stride(1);
} else {
TORCH_CHECK(false, "unsupported memory format");
}

at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

// Now we loop over the inputs in batches of CAT_ARRAY_BATCH_SIZE.
int batchCounter = 0;
int64_t offset = 0;
for (int i = 0; i < inputs.size() ; i += CAT_ARRAY_BATCH_SIZE) {
// Re-allocate stackInputs every iteration to avoid read-after-write hazard
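// (The non_blocking copy below reads from this pinned buffer on the stream;
// a fresh allocation per batch keeps the host from overwriting memory that a
// still-pending copy may be reading.)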
{
auto stackInputs_storage = at::empty({tensorMetadataSize},
out.options().dtype(at::kByte).device(at::kCPU).pinned_memory(true));
auto stackInputs =
static_cast<CatArrInputTensor<scalar_t, unsigned int> *>(
stackInputs_storage.data_ptr());
for (batchCounter = 0;
batchCounter < CAT_ARRAY_BATCH_SIZE &&
(i+batchCounter) < inputs.size();
++batchCounter) {
int64_t dimSize = at::native::size(inputs[i+batchCounter], dimension);

stackInputs[batchCounter].input =
inputs[i+batchCounter].data_ptr<scalar_t>();
stackInputs[batchCounter].offset = offset;
stackInputs[batchCounter].dimSize = dimSize;
stackInputs[batchCounter].nElements = inputs[i+batchCounter].numel();

// update offset
offset += dimSize;
}
at::native::copy_(d_inputs_storage, stackInputs_storage,
/* non_blocking= */ true);
}

// Next, let's set up our kernel launch parameters.
// We borrow from THCApply, on which the kernel's internal indexing is based.
dim3 applyBlock = dim3(32*16);
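// 512 threads per block, matching the C10_LAUNCH_BOUNDS_1(512) annotation on
// HIP_CatArrayBatchedCopy.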

// Get a grid whose x dim fills half the GPU and whose y dim is the number of
// tensors. This way, concatenating two tensors fills the entire grid, while
// preventing many threads from needlessly loading metadata when the inputs
// are small.
dim3 catGrid;
getCatGrid(batchCounter, catGrid);

if (memory_format != c10::MemoryFormat::Contiguous) {
switch (dimension) {
case 0:
break;
case 1:
dimension = nDims - dimension;
break;
default:
dimension--;
}
}
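// The size/stride arrays above were permuted to NHWC order for channels-last
// output, so the concat dimension must be remapped to match: dim 1 (C) maps
// to the last dim and the spatial dims shift down by one.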
// Template instantiations for dims = 1, 2, 3, 4
#define HANDLE_CASE(DIMS) \
HIP_CatArrayBatchedCopy<scalar_t, unsigned int, DIMS><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
data, d_inputs, param, dimension, param.outputStride[dimension]);
switch (nDims) {
case 1:
HANDLE_CASE(1);
break;
case 2:
HANDLE_CASE(2);
break;
case 3:
HANDLE_CASE(3);
break;
case 4:
HANDLE_CASE(4);
break;
}
#undef HANDLE_CASE
AT_CUDA_CHECK(cudaGetLastError());
}
}

template <typename scalar_t>
void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
int nDims, c10::MemoryFormat memory_format) {
@@ -235,7 +392,6 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
AT_CUDA_CHECK(cudaGetLastError());
}
}

} // namespace

Tensor cat_cuda(TensorList inputs, int64_t dimension) {
@@ -373,12 +529,19 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) {
all32BitIndexable &&
allSameType) {

#ifdef __HIP_PLATFORM_HCC__
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
out.scalar_type(), "cat_cuda", [&]() {
hip_parallel_cat<scalar_t>(out, inputs, dimension, nDims, memory_format);
});
#else
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
out.scalar_type(), "cat_cuda", [&]() {
parallel_cat<scalar_t>(out, inputs, dimension, nDims, memory_format);
});

#endif
} else {
int64_t offset = 0;
for (int j = 0; j < inputs.size(); j++)
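
For reference, a minimal sketch (not part of this diff) of exercising the batched path through the public ATen API; the shapes are arbitrary, and on ROCm builds the contiguous, same-dtype, 32-bit-indexable case dispatches to hip_parallel_cat as shown above:

#include <ATen/ATen.h>

int main() {
  // Two CUDA tensors that agree on every dim except the concat dim.
  auto a = at::randn({4, 8}, at::kCUDA);
  auto b = at::randn({2, 8}, at::kCUDA);
  // cat_cuda / cat_out_cuda route this to hip_parallel_cat on ROCm.
  auto c = at::cat({a, b}, /*dim=*/0);  // result shape: {6, 8}
  return 0;
}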