Commit 76c08bc

Update base for Update on "Improved perfs for vectorized bilinear interpolate cpu uint8 RGB-case (channels last)"
## Description

- Based on #96651
- Improved perfs for vectorized **bilinear** interpolate uint8 RGB-case, **channels last**
- Unified the RGB and RGBA processing code such that RGB input is not copied into RGBA
- Performance is now much closer to Pillow-SIMD (labeled as `Pillow (9.0.0.post1)` in the results)
- RGBA-case perfs are unchanged after the refactoring (see Source link below)
- Fixed mem pointer alignment, added more comments (reviews from #96651)

## Results

- `Pillow (9.0.0.post1)` == Pillow-SIMD

```
[------------------------------------------------------------ Resize ------------------------------------------------------------]
                                                                         | Pillow (9.0.0.post1) | torch (2.1.0a0+gitce4be01) PR | torch (2.1.0a0+git5309c44) nightly | Speed-up: PR vs nightly
1 threads: -----------------------------------------------------------------------------------------------------------------------
3 torch.uint8 channels_last bilinear (256, 256) -> (32, 32) aa=True      |   38.548 (+-0.280)   |    57.536 (+-0.210)           |    132.147 (+-1.236)               |  2.297 (+-0.000)
3 torch.uint8 channels_last bilinear (256, 256) -> (32, 32) aa=False     |                      |    38.532 (+-0.219)           |    111.789 (+-1.175)               |  2.901 (+-0.000)
3 torch.uint8 channels_last bilinear (256, 256) -> (224, 224) aa=True    |  127.689 (+-1.348)   |   156.262 (+-1.213)           |    302.518 (+-2.632)               |  1.936 (+-0.000)
3 torch.uint8 channels_last bilinear (256, 256) -> (224, 224) aa=False   |                      |   145.483 (+-1.077)           |    286.663 (+-2.494)               |  1.970 (+-0.000)
3 torch.uint8 channels_last bilinear (256, 256) -> (320, 320) aa=True    |  178.117 (+-1.956)   |   215.053 (+-1.470)           |    439.375 (+-4.014)               |  2.043 (+-0.000)
3 torch.uint8 channels_last bilinear (256, 256) -> (320, 320) aa=False   |                      |   211.340 (+-2.239)           |    438.537 (+-4.143)               |  2.075 (+-0.000)
3 torch.uint8 channels_last bilinear (520, 520) -> (32, 32) aa=True      |  112.593 (+-1.266)   |   130.414 (+-1.633)           |    446.804 (+-3.283)               |  3.426 (+-0.000)
3 torch.uint8 channels_last bilinear (520, 520) -> (32, 32) aa=False     |                      |    58.767 (+-0.203)           |    374.244 (+-13.598)              |  6.368 (+-0.000)
3 torch.uint8 channels_last bilinear (520, 520) -> (224, 224) aa=True    |  283.210 (+-2.937)   |   324.157 (+-1.895)           |    720.197 (+-3.467)               |  2.222 (+-0.000)
3 torch.uint8 channels_last bilinear (520, 520) -> (224, 224) aa=False   |                      |   239.800 (+-2.492)           |    592.834 (+-3.903)               |  2.472 (+-0.000)
3 torch.uint8 channels_last bilinear (712, 712) -> (32, 32) aa=True      |  186.255 (+-1.629)   |   204.834 (+-1.496)           |    787.868 (+-3.648)               |  3.846 (+-0.000)
3 torch.uint8 channels_last bilinear (712, 712) -> (32, 32) aa=False     |                      |    77.335 (+-0.341)           |    651.016 (+-3.926)               |  8.418 (+-0.000)
3 torch.uint8 channels_last bilinear (712, 712) -> (224, 224) aa=True    |  410.286 (+-2.439)   |   443.934 (+-2.899)           |   1123.923 (+-14.988)              |  2.532 (+-0.000)
3 torch.uint8 channels_last bilinear (712, 712) -> (224, 224) aa=False   |                      |   312.220 (+-2.307)           |    915.347 (+-4.486)               |  2.932 (+-0.000)

# More test-cases from #90771
3 torch.uint8 channels_last bilinear (64, 64) -> (224, 224) aa=True      |   60.611 (+-0.337)   |    80.849 (+-1.780)           |    170.465 (+-1.830)               |  2.108 (+-0.000)
3 torch.uint8 channels_last bilinear (224, 224) -> (270, 268) aa=True    |  132.971 (+-1.624)   |   164.892 (+-1.426)           |    330.971 (+-3.249)               |  2.007 (+-0.000)
3 torch.uint8 channels_last bilinear (256, 256) -> (1024, 1024) aa=True  |  948.467 (+-3.179)   |   891.414 (+-5.282)           |   2805.510 (+-25.503)              |  3.147 (+-0.000)
3 torch.uint8 channels_last bilinear (224, 224) -> (64, 64) aa=True      |   52.539 (+-0.327)   |    72.471 (+-0.367)           |    135.933 (+-1.625)               |  1.876 (+-0.000)
3 torch.uint8 channels_last bilinear (270, 268) -> (224, 224) aa=True    |  138.669 (+-1.867)   |   168.628 (+-1.213)           |    321.112 (+-2.904)               |  1.904 (+-0.000)
3 torch.uint8 channels_last bilinear (1024, 1024) -> (256, 256) aa=True  |  689.933 (+-3.175)   |   746.911 (+-2.985)           |   2050.880 (+-22.188)              |  2.746 (+-0.000)
3 torch.uint8 channels_last bilinear (64, 64) -> (224, 224) aa=False     |                      |    78.347 (+-0.338)           |    169.646 (+-1.640)               |  2.165 (+-0.000)
3 torch.uint8 channels_last bilinear (224, 224) -> (270, 268) aa=False   |                      |   162.194 (+-1.089)           |    329.754 (+-2.590)               |  2.033 (+-0.000)
3 torch.uint8 channels_last bilinear (256, 256) -> (1024, 1024) aa=False |                      |   894.476 (+-2.738)           |   2815.870 (+-22.589)              |  3.148 (+-0.000)
3 torch.uint8 channels_last bilinear (224, 224) -> (64, 64) aa=False     |                      |    52.728 (+-0.406)           |    112.024 (+-1.225)               |  2.125 (+-0.000)
3 torch.uint8 channels_last bilinear (270, 268) -> (224, 224) aa=False   |                      |   151.560 (+-1.128)           |    299.152 (+-3.353)               |  1.974 (+-0.000)
3 torch.uint8 channels_last bilinear (1024, 1024) -> (256, 256) aa=False |                      |   500.053 (+-4.288)           |   1698.601 (+-16.785)              |  3.397 (+-0.000)
```

Note: there is no perf regression for the other cases. Some cases (see Source below) show small speed-ups; for the rest, the ratio is roughly 1.0 +/- 0.1, which may be attributed to noisy measurements.

[Source](https://gist.github.com/vfdev-5/1c0778904a07ce40401306548b9525e8#file-20230322-132441-pr_vs_nightly-speedup-md)

## Context

- #90771

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
1 parent 036ed6c commit 76c08bc

File tree

1 file changed: +19, -17 lines


aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h

Lines changed: 19 additions & 17 deletions
```diff
@@ -39,13 +39,10 @@ static __m128i inline mm_cvtepu8_epi32(const uint32_t* C10_RESTRICT ptr) {
   return _mm_cvtepu8_epi32(_mm_cvtsi32_si128(*(int32_t*)ptr));
 }
 
-// TODO: We may want to hard-code an unrolled version for the case where
-// num_channels=3 to hint the compiler to vectorize this (looks at original
-// PIL-SIMD's code).
 at::Tensor unpack_rgb(const at::Tensor& packed_tensor) {
   // Convert a "packed" tensor (typically RGBRGBRGB if channels_last) into
-  // RGBARGBARGBA format where A is hard-coded to 255. Each pixel is encoded
-  // into as 32bits. This generalizes to num_channels <= 4 and also works for
+  // RGBARGBARGBA format where A is hard-coded to 0. Each pixel is encoded
+  // into as 32 bits. This generalizes to num_channels <= 4 and also works for
   // non-channels_last tensors.
 
   const uint8_t* packed = (const uint8_t*)packed_tensor.data_ptr<uint8_t>();
```
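As a standalone scalar sketch (hypothetical helper names, not the PR's actual vectorized kernel), the unpacking described in the comment above widens each RGB pixel into a 4-byte RGBA slot, with alpha hard-coded to 0, so that one pixel occupies a single 32-bit lane:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical scalar illustration of the unpack step: RGBRGBRGB bytes
// become RGBARGBARGBA bytes, A = 0, one uint32-sized slot per pixel.
std::vector<std::uint8_t> unpack_rgb_to_rgba(const std::vector<std::uint8_t>& rgb) {
  std::vector<std::uint8_t> rgba;
  rgba.reserve(rgb.size() / 3 * 4);
  for (std::size_t i = 0; i < rgb.size(); i += 3) {
    rgba.push_back(rgb[i]);      // R
    rgba.push_back(rgb[i + 1]);  // G
    rgba.push_back(rgb[i + 2]);  // B
    rgba.push_back(0);           // A, hard-coded to 0 as in the updated comment
  }
  return rgba;
}
```

The point of the PR, per the description, is precisely to avoid paying this copy for RGB channels-last inputs in the common paths.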
```diff
@@ -71,7 +68,7 @@ void pack_rgb(
     const at::Tensor& unpacked_tensor, // IN
     const at::Tensor& packed_tensor // OUT
 ) {
-  // Convert back RGBA into RGB.
+  // Convert from unpacked channels last 4-channels tensor into original data layout.
 
   constexpr int rgba_size = 4;
   uint8_t* unpacked = (uint8_t*)unpacked_tensor.data_ptr<uint8_t>();
```
```diff
@@ -315,28 +312,31 @@ void upsample_avx_bilinear_uint8(
         /*align_i32=*/true);
   }
 
-  bool is_rgba = num_channels == 4 && input.is_contiguous(at::MemoryFormat::ChannelsLast);
+  bool needs_unpacking = num_channels == 4 && input.is_contiguous(at::MemoryFormat::ChannelsLast);
 
   at::Tensor buffer_horiz, buffer_vert;
-  if (need_horizontal && !(is_rgba && !need_vertical)) {
+  // Minor optimization: we can avoid allocating an extra buffer if we're performing
+  // horizontal-only or vertical-only interpolation, and if the tensor doesn't
+  // need unpacking
+  if (need_horizontal && !(needs_unpacking && !need_vertical)) {
     buffer_horiz = at::empty({4, yin, xout}, input.options());
   }
-  if (need_vertical && !is_rgba) {
+  if (need_vertical && !needs_unpacking) {
     buffer_vert = at::empty({4, yout, xout}, input.options());
   }
 
-  // TODO: The unpack / pack operations create a copy of the original input and
-  // output tensor. There should be a way to avoid these copies by instead
+  // TODO: The unpack / pack operations create a
+  // copy of the original input and output tensor. There should be a way to avoid these copies by instead
   // modifying the low-level kernels. Or maybe at least avoid copying the entire
   // tensors and just copy part of them (line by line).
   for (const auto i : c10::irange(batch_size)) {
 
-    at::Tensor unpacked_input = (is_rgba) ? input[i] : unpack_rgb(input[i]);
+    at::Tensor unpacked_input = (needs_unpacking) ? input[i] : unpack_rgb(input[i]);
     at::Tensor unpacked_output;
 
     if (need_horizontal) {
 
-      at::Tensor unpacked_output_temp = (is_rgba && !need_vertical) ? output[i] : buffer_horiz;
+      at::Tensor unpacked_output_temp = (needs_unpacking && !need_vertical) ? output[i] : buffer_horiz;
 
       ImagingResampleHorizontal(
           unpacked_output_temp,
```
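The buffer conditions in this hunk can be restated as standalone predicates (hypothetical helpers, written only to illustrate the hunk's logic; in the diff, `needs_unpacking == true` is the case where the tensor is already 4-channel channels-last and can be used directly):

```cpp
#include <cassert>

// Restatement of the hunk's buffer-allocation conditions: an intermediate
// buffer is only needed when a pass cannot write directly into the output.
bool needs_horiz_buffer(bool needs_unpacking, bool need_horizontal, bool need_vertical) {
  // Horizontal-only resize on a directly-usable tensor writes straight to output[i].
  return need_horizontal && !(needs_unpacking && !need_vertical);
}

bool needs_vert_buffer(bool needs_unpacking, bool need_vertical) {
  // The vertical pass writes straight to output[i] for directly-usable tensors.
  return need_vertical && !needs_unpacking;
}
```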
```diff
@@ -347,7 +347,7 @@ void upsample_avx_bilinear_uint8(
       unpacked_output = unpacked_input = unpacked_output_temp;
     }
     if (need_vertical) {
-      unpacked_output = (is_rgba) ? output[i] : buffer_vert;
+      unpacked_output = (needs_unpacking) ? output[i] : buffer_vert;
 
       ImagingResampleVertical(
           unpacked_output,
```
```diff
@@ -359,7 +359,7 @@ void upsample_avx_bilinear_uint8(
 
     TORCH_INTERNAL_ASSERT(unpacked_output.defined());
 
-    if (!is_rgba) {
+    if (!needs_unpacking) {
       pack_rgb(unpacked_output, output[i]);
     }
   }
```
```diff
@@ -382,7 +382,8 @@ void ImagingResampleHorizontalConvolution8u4x(
     unsigned int coefs_precision) {
   // Interpolation horizontal pass processing together 4 vertical lines.
   // - Input data format is RGBA with R,G,B,A being uint8, we can encode 4 values as a single uint32 value.
-  // - We split the size of weight vector for a given output index as a sum: K = n * 4 + m * 2 + k.
+  // - We split the size of weight vector for a given output index as a sum:
+  //   ids_size = num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1.
   // - We load and process 4 weights values in a loop ("block 4") then we process 2 weights values
   // in another loop ("block 2") and finally we process 1 weights value in the final loop ("block 1").
```

```diff
@@ -549,7 +550,8 @@ void ImagingResampleHorizontalConvolution8u(
 
   // Interpolation horizontal pass processing only one vertical line.
   // - Input data format is RGBA with R,G,B,A being uint8, we can encode 4 values as a single uint32 value.
-  // - We split the size of weight vector for a given output index as a sum: K = n * 8 + m * 4 + k * 2 + l.
+  // - We split the size of weight vector for a given output index as a sum:
+  //   ids_size = num_blocks_8 * 8 + num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1
   // - We load and process 8 weights values in a loop ("block 8") then 4 weights and 2 weights values in
   // in another loops ("block 4" and "block 2") and finally we process 1 weight value in the final loop ("block 1").
```
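The renamed comment terms can be made concrete with a small sketch (a hypothetical helper, not part of the file): greedily splitting `ids_size` into runs of 8, 4, 2 and 1, so that each run maps onto one of the fixed-width SIMD loops ("block 8", "block 4", "block 2", "block 1"):

```cpp
#include <cassert>

// Hypothetical illustration of the decomposition in the updated comments:
// ids_size = num_blocks_8 * 8 + num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1
struct Blocks {
  int num_blocks_8;
  int num_blocks_4;
  int num_blocks_2;
  int num_blocks_1;
};

Blocks split_ids_size(int ids_size) {
  Blocks b{};
  b.num_blocks_8 = ids_size / 8;  // full 8-wide iterations
  ids_size %= 8;
  b.num_blocks_4 = ids_size / 4;  // then 4-wide
  ids_size %= 4;
  b.num_blocks_2 = ids_size / 2;  // then 2-wide
  b.num_blocks_1 = ids_size % 2;  // scalar tail
  return b;
}
```

For example, a weight vector of length 15 decomposes into one block of each width (8 + 4 + 2 + 1).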
