
Commit c52a61f

Performance fix for torch.cat operator on ROCm (#46097) (#46323)
Summary: This pull request is a partial revert of #44833 for ROCm to fix the performance of the concatenate operator. The changes affect only execution on ROCm and are guarded by the define `__HIP_PLATFORM_HCC__`.

Pull Request resolved: #46097

Test Plan: Benchmark with `python -m pt.cat_test --tag_filter all --device cuda`

Results on ROCm before and after the PR (PyTorch/Caffe2 operator micro-benchmarks, tag `all`, operator `cat`, Eager mode; forward execution time in microseconds). `<lambda>` marks a dimension whose size is generated by a lambda in the benchmark config; the raw log printed its function address (e.g. `<function <lambda> at 0x7f989e5c2510>`).

| Input sizes | N | dim | Before (us) | After (us) |
|---|---|---|---|---|
| (1, 1, 1) | 2 | 0 | 10828.314 | 29.292 |
| (512, 512, 2) | 2 | 1 | 11888.028 | 46.320 |
| (128, 1024, 2) | 2 | 1 | 11898.945 | 36.969 |
| (1024, 1024, 2) | 2 | 0 | 11787.744 | 92.816 |
| (1025, 1023, 2) | 2 | 1 | 11792.479 | 93.943 |
| (1024, 1024, 2) | 2 | 2 | 11769.718 | 163.914 |
| [\<lambda\>, 111, 65] | 5 | 0 | 11633.882 | 75.475 |
| [96, \<lambda\>, 64] | 5 | 1 | 11617.768 | 68.880 |
| [128, 64, \<lambda\>] | 5 | 2 | 11625.143 | 85.268 |
| [\<lambda\>, 32, 64] | 50 | 0 | 13079.204 | 111.543 |
| [32, \<lambda\>, 64] | 50 | 1 | 13095.620 | 110.644 |
| [33, 65, \<lambda\>] | 50 | 2 | 13403.086 | 116.201 |
| (64, 32, 4, 16, 32) | 2 | 2 | 118.704 | 117.708 |
| (16, 32, 4, 16, 32) | 8 | 2 | 263.273 | 264.953 |
| (9, 31, 5, 15, 33) | 17 | 4 | 463.024 | 480.304 |
| [\<lambda\>] | 100 | 0 | 23818.032 | 116.385 |
| [\<lambda\>] | 1000 | 0 | 234778.296 | 913.591 |
| [\<lambda\>] | 2000 | 0 | 470288.132 | 2003.212 |
| [\<lambda\>] | 3000 | 0 | 704361.221 | 3004.174 |

Reviewed By: bdhirsh

Differential Revision: D24286324

Pulled By: malfet

fbshipit-source-id: 291f3f3f80f9d2f9ba52a455a942f3fb0406e7d2
1 parent 1c28571 commit c52a61f

1 file changed (+168 −5)

aten/src/ATen/native/cuda/Shape.cu

```diff
@@ -12,7 +12,11 @@
 namespace at {
 namespace native {
 
+#ifdef __HIP_PLATFORM_HCC__
+constexpr int CAT_ARRAY_BATCH_SIZE = 1024;
+#else
 constexpr int CAT_ARRAY_BATCH_SIZE = 128;
+#endif
 constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 4;
 
 namespace {
```
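
The larger ROCm batch size pairs with the metadata-passing change below: on the HIP path, per-tensor metadata travels through a pinned staging buffer and reaches the kernel as a device pointer, so the batch is bounded only by that buffer, not by kernel-argument space (traditionally 4 KB on CUDA, which is presumably what caps the by-value path at 128 entries). A back-of-the-envelope check of that size arithmetic, as a minimal sketch with a simplified stand-in struct (illustration only, not the exact layout in Shape.cu):

```cuda
#include <cstdio>

// Simplified stand-in for the per-input metadata entry in this diff
// (illustration only; the real by-value path carries more per batch).
template <typename T, typename IndexType>
struct CatArrInputTensorSketch {
  T* input;
  IndexType offset;
  IndexType dimSize;
  IndexType nElements;
};

int main() {
  using Entry = CatArrInputTensorSketch<float, unsigned int>;

  // By-value path (CUDA): all metadata rides in kernel-argument space,
  // so the batch stays small enough to fit the ~4 KB budget.
  std::printf("by-value, batch 128:    %zu bytes of kernel arguments\n",
              128 * sizeof(Entry));

  // By-pointer path (ROCm): the kernel argument is one pointer;
  // 1024 entries live in a staged device buffer instead.
  std::printf("by-pointer, batch 1024: %zu bytes staged, %zu-byte argument\n",
              1024 * sizeof(Entry), sizeof(Entry*));
  return 0;
}
```

With a 24-byte entry, 128 by-value entries already occupy about 3 KB of argument space, while 1024 staged entries cost only a single pointer as a kernel argument.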
```diff
@@ -78,6 +82,46 @@ struct OutputTensorSizeStride {
  * The most important assumption made is that the input tensors are contiguous.
  */
 
+
+// Use pinned memory and pass the struct by pointer on ROCm
+template <typename T, typename IndexType>
+struct CatArrInputTensor {
+  T* input;
+  IndexType offset;
+  IndexType dimSize;
+  IndexType nElements;
+};
+
+template <typename T, typename IndexType, int Dims>
+C10_LAUNCH_BOUNDS_1(512)
+__global__ void HIP_CatArrayBatchedCopy(
+    T* output,
+    CatArrInputTensor<T, IndexType>* inputs,
+    OutputTensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
+    const int concatDim,
+    IndexType dimStride) {
+
+  IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
+  IndexType nElements = inputs[blockIdx.y].nElements;
+
+  if (tid >= nElements) return;
+
+  T* data = inputs[blockIdx.y].input;
+  IndexType offset = inputs[blockIdx.y].offset;
+  IndexType dimSize = inputs[blockIdx.y].dimSize;
+  IndexType dataOffset = offset * dimStride;
+
+  IndexType stride = gridDim.x * blockDim.x;
+
+  while (tid < nElements) {
+    IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
+        os.outputSize, os.outputStride, dimSize, concatDim, tid);
+    output[dataOffset + elementOffset] = data[tid];
+
+    tid += stride;
+  }
+}
+
 // pass meta data directly through kernel argument instead of pin memory
 template <typename T, typename IndexType, int n>
 struct CatArrInputTensorMetadata {
```
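
The kernel above uses a 2D grid: `blockIdx.y` selects which input tensor a block works on, and the x-dimension threads cooperatively copy that tensor's elements in a grid-stride loop, computing each destination from the output's size/stride arrays. A stripped-down sketch of the same pattern (hypothetical names; dim-0 concat of contiguous 1-D inputs, so the index math collapses to a flat offset):

```cuda
#include <cuda_runtime.h>

// Hypothetical, stripped-down analogue of HIP_CatArrayBatchedCopy.
struct SrcDesc {
  const float* data;   // one input tensor's buffer
  unsigned offset;     // flattened start position in the output
  unsigned nElements;  // elements this input contributes
};

__global__ void batched_concat_copy(float* out, const SrcDesc* srcs) {
  // blockIdx.y picks the input tensor; x-threads grid-stride over it.
  const SrcDesc s = srcs[blockIdx.y];
  for (unsigned tid = blockIdx.x * blockDim.x + threadIdx.x;
       tid < s.nElements;
       tid += gridDim.x * blockDim.x) {
    out[s.offset + tid] = s.data[tid];
  }
}
```

Launched as, say, `batched_concat_copy<<<dim3(blocksX, numInputs), 512>>>(out, d_srcs)`, every input gets its own row of blocks regardless of size, which is why small inputs don't need many threads to reload metadata.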
```diff
@@ -88,9 +132,6 @@ struct CatArrInputTensorMetadata {
 };
 
 template <typename T, typename IndexType, int Dims>
-#ifdef __HIP_PLATFORM_HCC__
-C10_LAUNCH_BOUNDS_1(512)
-#endif
 __global__ void CatArrayBatchedCopy(
     T* output,
     CatArrInputTensorMetadata<T, IndexType, CAT_ARRAY_BATCH_SIZE> inputs,
```
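
With the new HIP-only kernel carrying its own `C10_LAUNCH_BOUNDS_1(512)`, the `#ifdef` guard on the shared `CatArrayBatchedCopy` kernel is no longer needed. That macro is PyTorch's wrapper around the `__launch_bounds__` qualifier, which promises a maximum block size so the compiler can budget registers accordingly; roughly:

```cuda
#include <cuda_runtime.h>

// __launch_bounds__(512) promises this kernel is never launched with
// more than 512 threads per block, letting the compiler allocate
// registers for that occupancy instead of the worst case.
__global__ void __launch_bounds__(512) scale_by_two(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] *= 2.0f;
}
```

The launch in `hip_parallel_cat` below uses 512-thread blocks (`dim3(32*16)`), matching the bound.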
```diff
@@ -141,6 +182,122 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
   }
 }
 
+template <typename scalar_t>
+void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
+                      int nDims, c10::MemoryFormat memory_format) {
+  // First, let's set up our kernel parameters. We start with a raw pointer to
+  // the storage for the output Tensor.
+  scalar_t *data = out.data_ptr<scalar_t>();
+
+  // Kernel Parameter
+  long tensorMetadataSize =
+    sizeof(CatArrInputTensor<scalar_t, unsigned int>) * CAT_ARRAY_BATCH_SIZE;
+  auto d_inputs_storage = at::empty(
+    {tensorMetadataSize}, out.options().dtype(at::kByte));
+  auto d_inputs = static_cast<CatArrInputTensor<scalar_t, unsigned int> *>(
+    d_inputs_storage.data_ptr());
+
+  OutputTensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> param;
+
+  // Next, let's initialize the size, stride arrays for the output Tensor.
+  if (memory_format == c10::MemoryFormat::Contiguous) {
+    for (int i = 0; i < nDims; ++i) {
+      param.outputSize[i] = at::native::size(out, i);
+      param.outputStride[i] = out.stride(i);
+    }
+  } else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) {
+    // permute the semantics of dims from NCHW to NHWC so that the input
+    // tensor is now contiguous
+    param.outputSize[0] = at::native::size(out, 0);
+    param.outputStride[0] = out.stride(0);
+    for (int i = 1; i < nDims - 1; ++i) {
+      param.outputSize[i] = at::native::size(out, i + 1);
+      param.outputStride[i] = out.stride(i + 1);
+    }
+    param.outputSize[nDims - 1] = at::native::size(out, 1);
+    param.outputStride[nDims - 1] = out.stride(1);
+  } else {
+    TORCH_CHECK(false, "unsupported memory format");
+  }
+
+  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
+
+  // Now we loop
+  int batchCounter = 0;
+  int64_t offset = 0;
+  for (int i = 0; i < inputs.size() ; i += CAT_ARRAY_BATCH_SIZE) {
+    // Re-allocate stackInputs every iteration to avoid read-after-write hazard
+    {
+      auto stackInputs_storage = at::empty({tensorMetadataSize},
+          out.options().dtype(at::kByte).device(at::kCPU).pinned_memory(true));
+      auto stackInputs =
+        static_cast<CatArrInputTensor<scalar_t, unsigned int> *>(
+            stackInputs_storage.data_ptr());
+      for (batchCounter = 0;
+           batchCounter < CAT_ARRAY_BATCH_SIZE &&
+             (i+batchCounter) < inputs.size();
+           ++batchCounter) {
+        int64_t dimSize = at::native::size(inputs[i+batchCounter], dimension);
+
+        stackInputs[batchCounter].input =
+          inputs[i+batchCounter].data_ptr<scalar_t>();
+        stackInputs[batchCounter].offset = offset;
+        stackInputs[batchCounter].dimSize = dimSize;
+        stackInputs[batchCounter].nElements = inputs[i+batchCounter].numel();
+
+        // update offset
+        offset += dimSize;
+      }
+      at::native::copy_(d_inputs_storage, stackInputs_storage,
+                        /* non_blocking= */ true);
+    }
+
```
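
The loop above is the heart of the ROCm path: metadata for up to `CAT_ARRAY_BATCH_SIZE` inputs is packed into a pinned (page-locked) host buffer, then copied to the device with a non-blocking `copy_` on the current stream; because the kernel launch is queued on the same stream, it is ordered after the copy. A minimal sketch of that staging idiom in raw CUDA/HIP runtime calls (hypothetical sizes; asserts in place of real error handling):

```cuda
#include <cassert>
#include <cuda_runtime.h>

int main() {
  constexpr size_t kBytes = 1024 * sizeof(int);
  cudaStream_t stream;
  assert(cudaStreamCreate(&stream) == cudaSuccess);

  // Pinned host buffer: required for a truly asynchronous H2D copy.
  int* staged = nullptr;
  assert(cudaHostAlloc(&staged, kBytes, cudaHostAllocDefault) == cudaSuccess);

  int* d_meta = nullptr;
  assert(cudaMalloc(&d_meta, kBytes) == cudaSuccess);

  for (int i = 0; i < 1024; ++i) staged[i] = i;  // fill metadata on the CPU

  // Async copy ordered on the stream; a kernel enqueued afterwards on the
  // same stream is guaranteed to see the copied metadata.
  assert(cudaMemcpyAsync(d_meta, staged, kBytes, cudaMemcpyHostToDevice,
                         stream) == cudaSuccess);
  // ... launch the kernel on `stream` here, then synchronize before reusing
  // `staged` (the diff instead re-allocates the staging buffer per batch,
  // which is how it avoids the read-after-write hazard).
  assert(cudaStreamSynchronize(stream) == cudaSuccess);

  cudaFree(d_meta);
  cudaFreeHost(staged);
  cudaStreamDestroy(stream);
}
```

The same hunk continues with the launch configuration: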
```diff
+    // Next, let's consider how we set our kernel launch parameters.
+    // We borrow from THCApply, which the kernel's internal indexing
+    // is based on.
+    dim3 applyBlock = dim3(32*16);
+
+    // Get a grid where the x dim fills half the GPU and the y dim is the
+    // number of tensors. Cat'ing two tensors then fills the entire grid,
+    // while small tensors are prevented from needlessly spawning many
+    // threads that only load metadata.
+    dim3 catGrid;
+    getCatGrid(batchCounter, catGrid);
+
+    if (memory_format != c10::MemoryFormat::Contiguous) {
+      switch (dimension) {
+      case 0:
+        break;
+      case 1:
+        dimension = nDims - dimension;
+        break;
+      default:
+        dimension--;
+      }
+    }
+    // Template Declarations for dim = 1, 2, 3, 4
+#define HANDLE_CASE(DIMS) \
+    HIP_CatArrayBatchedCopy<scalar_t, unsigned int, DIMS><<<\
+        catGrid, applyBlock, 0, stream.stream()>>>(\
+        data, d_inputs, param, dimension, param.outputStride[dimension]);
+    switch (nDims) {
+      case 1:
+        HANDLE_CASE(1);
+        break;
+      case 2:
+        HANDLE_CASE(2);
+        break;
+      case 3:
+        HANDLE_CASE(3);
+        break;
+      case 4:
+        HANDLE_CASE(4);
+        break;
+    }
+#undef HANDLE_CASE
+    AT_CUDA_CHECK(cudaGetLastError());
+  }
+}
+
 template <typename scalar_t>
 void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
                   int nDims, c10::MemoryFormat memory_format) {
```
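
For channels-last outputs, the size/stride arrays were permuted from NCHW to NHWC order earlier in the function so the physically contiguous layout can be indexed with plain size/stride math, and the concat `dimension` is remapped here to match: dim 0 stays put, dim 1 (channels) moves to the end, and higher dims shift down by one. A small host-only check of that remap (hypothetical helper mirroring the `switch` above):

```cuda
#include <cstdio>

// Mirror of the dim remap in hip_parallel_cat for non-contiguous
// (channels-last) outputs: NCHW dim -> position in NHWC order.
static int remap_cat_dim(int dimension, int nDims) {
  if (dimension == 0) return 0;          // N stays first
  if (dimension == 1) return nDims - 1;  // C moves to the end
  return dimension - 1;                  // H, W, ... shift down by one
}

int main() {
  // A 4-D NCHW tensor viewed as NHWC: N=0, H=1, W=2, C=3.
  for (int d = 0; d < 4; ++d)
    printf("NCHW dim %d -> NHWC dim %d\n", d, remap_cat_dim(d, 4));
  // Prints: 0->0, 1->3, 2->1, 3->2
}
```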
```diff
@@ -235,7 +392,6 @@ void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
   AT_CUDA_CHECK(cudaGetLastError());
 }
 }
-
 } // namespace
 
 Tensor cat_cuda(TensorList inputs, int64_t dimension) {
```
```diff
@@ -373,12 +529,19 @@ Tensor& cat_out_cuda(Tensor& out, TensorList inputs, int64_t dimension) {
       all32BitIndexable &&
       allSameType) {
 
+#ifdef __HIP_PLATFORM_HCC__
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
+        at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
+        out.scalar_type(), "cat_cuda", [&]() {
+          hip_parallel_cat<scalar_t>(out, inputs, dimension, nDims, memory_format);
+        });
+#else
     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
         at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
         out.scalar_type(), "cat_cuda", [&]() {
          parallel_cat<scalar_t>(out, inputs, dimension, nDims, memory_format);
        });
-
+#endif
   } else {
     int64_t offset = 0;
     for (int j = 0; j < inputs.size(); j++)
```

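`AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3` expands, roughly, to a switch over the runtime dtype that runs the lambda with `scalar_t` bound to the matching C++ type; the `#ifdef` above merely swaps which templated implementation that body calls. A macro-free sketch of the dispatch shape (simplified, hypothetical dtype enum; not ATen's actual expansion):

```cuda
#include <cstdio>

enum class ScalarType { Float, Double, Half };

// Hand-rolled stand-in for AT_DISPATCH_*: pick a C++ type from a runtime
// dtype tag and run the given body with `scalar_t` bound to it.
template <typename F>
void dispatch(ScalarType t, F&& f) {
  switch (t) {
    case ScalarType::Float:  f.template operator()<float>();  break;
    case ScalarType::Double: f.template operator()<double>(); break;
    default: printf("unsupported dtype\n");
  }
}

struct CatBody {
  template <typename scalar_t>
  void operator()() const {
    // In Shape.cu this is where hip_parallel_cat<scalar_t>(...) (ROCm)
    // or parallel_cat<scalar_t>(...) (CUDA) gets called.
    printf("running cat body for a %zu-byte scalar\n", sizeof(scalar_t));
  }
};

int main() { dispatch(ScalarType::Float, CatBody{}); }
```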