Commit bd39def

Update on "Improved perfs for vectorized bilinear interpolate cpu uint8 RGB-case (channels last)"
## Description

- Based on #96651
- Improved performance of the vectorized **bilinear** interpolation for the uint8 RGB case (**channels last**)
- Unified the RGB and RGBA processing code so that RGB input is no longer copied into RGBA
- Performance is now closer to Pillow-SIMD (labeled `Pillow (9.0.0.post1)` in the results)
- RGBA-case performance is unchanged after the refactoring (see the Source link below)
- Fixed memory pointer alignment and added more comments (addressing reviews from #96651)

## Results

- `Pillow (9.0.0.post1)` == Pillow-SIMD

```
[---------------------------------------------------------------- Resize ----------------------------------------------------------------]
                                                                      |  Pillow (9.0.0.post1)  |  torch (2.1.0a0+gitce4be01) PR  |  torch (2.1.0a0+git5309c44) nightly  |  Speed-up: PR vs nightly
1 threads: -------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_last bilinear (256, 256) -> (32, 32) aa=True      |    38.548 (+-0.280)  |    57.536 (+-0.210)  |    132.147 (+-1.236)   |  2.297 (+-0.000)
      3 torch.uint8 channels_last bilinear (256, 256) -> (32, 32) aa=False     |                      |    38.532 (+-0.219)  |    111.789 (+-1.175)   |  2.901 (+-0.000)
      3 torch.uint8 channels_last bilinear (256, 256) -> (224, 224) aa=True    |   127.689 (+-1.348)  |   156.262 (+-1.213)  |    302.518 (+-2.632)   |  1.936 (+-0.000)
      3 torch.uint8 channels_last bilinear (256, 256) -> (224, 224) aa=False   |                      |   145.483 (+-1.077)  |    286.663 (+-2.494)   |  1.970 (+-0.000)
      3 torch.uint8 channels_last bilinear (256, 256) -> (320, 320) aa=True    |   178.117 (+-1.956)  |   215.053 (+-1.470)  |    439.375 (+-4.014)   |  2.043 (+-0.000)
      3 torch.uint8 channels_last bilinear (256, 256) -> (320, 320) aa=False   |                      |   211.340 (+-2.239)  |    438.537 (+-4.143)   |  2.075 (+-0.000)
      3 torch.uint8 channels_last bilinear (520, 520) -> (32, 32) aa=True      |   112.593 (+-1.266)  |   130.414 (+-1.633)  |    446.804 (+-3.283)   |  3.426 (+-0.000)
      3 torch.uint8 channels_last bilinear (520, 520) -> (32, 32) aa=False     |                      |    58.767 (+-0.203)  |    374.244 (+-13.598)  |  6.368 (+-0.000)
      3 torch.uint8 channels_last bilinear (520, 520) -> (224, 224) aa=True    |   283.210 (+-2.937)  |   324.157 (+-1.895)  |    720.197 (+-3.467)   |  2.222 (+-0.000)
      3 torch.uint8 channels_last bilinear (520, 520) -> (224, 224) aa=False   |                      |   239.800 (+-2.492)  |    592.834 (+-3.903)   |  2.472 (+-0.000)
      3 torch.uint8 channels_last bilinear (712, 712) -> (32, 32) aa=True      |   186.255 (+-1.629)  |   204.834 (+-1.496)  |    787.868 (+-3.648)   |  3.846 (+-0.000)
      3 torch.uint8 channels_last bilinear (712, 712) -> (32, 32) aa=False     |                      |    77.335 (+-0.341)  |    651.016 (+-3.926)   |  8.418 (+-0.000)
      3 torch.uint8 channels_last bilinear (712, 712) -> (224, 224) aa=True    |   410.286 (+-2.439)  |   443.934 (+-2.899)  |   1123.923 (+-14.988)  |  2.532 (+-0.000)
      3 torch.uint8 channels_last bilinear (712, 712) -> (224, 224) aa=False   |                      |   312.220 (+-2.307)  |    915.347 (+-4.486)   |  2.932 (+-0.000)

      # More test-cases from #90771
      3 torch.uint8 channels_last bilinear (64, 64) -> (224, 224) aa=True      |    60.611 (+-0.337)  |    80.849 (+-1.780)  |    170.465 (+-1.830)   |  2.108 (+-0.000)
      3 torch.uint8 channels_last bilinear (224, 224) -> (270, 268) aa=True    |   132.971 (+-1.624)  |   164.892 (+-1.426)  |    330.971 (+-3.249)   |  2.007 (+-0.000)
      3 torch.uint8 channels_last bilinear (256, 256) -> (1024, 1024) aa=True  |   948.467 (+-3.179)  |   891.414 (+-5.282)  |   2805.510 (+-25.503)  |  3.147 (+-0.000)
      3 torch.uint8 channels_last bilinear (224, 224) -> (64, 64) aa=True      |    52.539 (+-0.327)  |    72.471 (+-0.367)  |    135.933 (+-1.625)   |  1.876 (+-0.000)
      3 torch.uint8 channels_last bilinear (270, 268) -> (224, 224) aa=True    |   138.669 (+-1.867)  |   168.628 (+-1.213)  |    321.112 (+-2.904)   |  1.904 (+-0.000)
      3 torch.uint8 channels_last bilinear (1024, 1024) -> (256, 256) aa=True  |   689.933 (+-3.175)  |   746.911 (+-2.985)  |   2050.880 (+-22.188)  |  2.746 (+-0.000)
      3 torch.uint8 channels_last bilinear (64, 64) -> (224, 224) aa=False     |                      |    78.347 (+-0.338)  |    169.646 (+-1.640)   |  2.165 (+-0.000)
      3 torch.uint8 channels_last bilinear (224, 224) -> (270, 268) aa=False   |                      |   162.194 (+-1.089)  |    329.754 (+-2.590)   |  2.033 (+-0.000)
      3 torch.uint8 channels_last bilinear (256, 256) -> (1024, 1024) aa=False |                      |   894.476 (+-2.738)  |   2815.870 (+-22.589)  |  3.148 (+-0.000)
      3 torch.uint8 channels_last bilinear (224, 224) -> (64, 64) aa=False     |                      |    52.728 (+-0.406)  |    112.024 (+-1.225)   |  2.125 (+-0.000)
      3 torch.uint8 channels_last bilinear (270, 268) -> (224, 224) aa=False   |                      |   151.560 (+-1.128)  |    299.152 (+-3.353)   |  1.974 (+-0.000)
      3 torch.uint8 channels_last bilinear (1024, 1024) -> (256, 256) aa=False |                      |   500.053 (+-4.288)  |   1698.601 (+-16.785)  |  3.397 (+-0.000)
```

Note: there is no perf regression for the other cases. Some cases (see Source below) show small speed-ups; the rest are roughly at 1.0 +/- 0.1, which may be attributed to noisy measurements.

[Source](https://gist.github.com/vfdev-5/1c0778904a07ce40401306548b9525e8#file-20230322-132441-pr_vs_nightly-speedup-md)

## Context

- #90771

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
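The unification above relies on a property the diff's comments spell out: a channels_last RGBA pixel is exactly 4 contiguous uint8 values, so it can be moved as one 32-bit integer. A pure-Python sketch of that encoding (illustrative only, not the ATen code; `pixel_as_uint32` is a hypothetical name):

```python
import struct

def pixel_as_uint32(r: int, g: int, b: int, a: int = 0) -> int:
    # Pack one channels_last RGBA pixel into a single little-endian uint32.
    # This is what lets the SIMD kernels load/store a whole pixel at once.
    return struct.unpack("<I", bytes([r, g, b, a]))[0]

px = pixel_as_uint32(0x11, 0x22, 0x33)
# The channel bytes are recoverable by shifting the 32-bit value.
assert (px & 0xFF, (px >> 8) & 0xFF, (px >> 16) & 0xFF) == (0x11, 0x22, 0x33)
```

An RGB pixel is only 3 bytes, which is why the previous code copied RGB into RGBA first; this PR instead adds dedicated 3-channel handling so the copy is avoided.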
2 parents: 07d7584 + 76c08bc


aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h

Lines changed: 21 additions & 22 deletions
```diff
@@ -60,13 +60,10 @@ static inline void _write_endline_rgb_as_uint32(
   std::memcpy(output, data_ptr, 4);
 }
 
-// TODO: We may want to hard-code an unrolled version for the case where
-// num_channels=3 to hint the compiler to vectorize this (looks at original
-// PIL-SIMD's code).
 at::Tensor unpack_rgb(const at::Tensor& packed_tensor) {
   // Convert a "packed" tensor (typically RGBRGBRGB if channels_last) into
-  // RGBARGBARGBA format where A is hard-coded to 255. Each pixel is encoded
-  // into as 32bits. This generalizes to num_channels <= 4 and also works for
+  // RGBARGBARGBA format where A is hard-coded to 0. Each pixel is encoded
+  // into as 32 bits. This generalizes to num_channels <= 4 and also works for
   // non-channels_last tensors.
 
   const uint8_t* packed = (const uint8_t*)packed_tensor.data_ptr<uint8_t>();
```
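The behavior of `unpack_rgb` described in the comment above can be modeled in a few lines of pure Python (an illustrative sketch, not the ATen implementation):

```python
def unpack_rgb(packed: bytes, num_channels: int = 3) -> bytes:
    # Model of the unpack step: each packed pixel becomes a 4-byte RGBA
    # pixel; missing channels (e.g. alpha for RGB input) are hard-coded to 0.
    out = bytearray()
    for i in range(0, len(packed), num_channels):
        out += packed[i:i + num_channels] + bytes(4 - num_channels)
    return bytes(out)

# Two RGB pixels -> two RGBA pixels with A = 0.
assert unpack_rgb(bytes([10, 20, 30, 40, 50, 60])) == bytes(
    [10, 20, 30, 0, 40, 50, 60, 0])
```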
```diff
@@ -92,7 +89,7 @@ void pack_rgb(
     const at::Tensor& unpacked_tensor, // IN
     const at::Tensor& packed_tensor // OUT
 ) {
-  // Convert back RGBA into RGB.
+  // Convert from unpacked channels last 4-channels tensor into original data layout.
 
   constexpr int rgba_size = 4;
   uint8_t* unpacked = (uint8_t*)unpacked_tensor.data_ptr<uint8_t>();
```
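A matching pure-Python model of the pack step (illustrative sketch only): it inverts the unpack above by keeping the first `num_channels` bytes of each 4-byte pixel and dropping the padding byte(s).

```python
def pack_rgb(unpacked: bytes, num_channels: int = 3) -> bytes:
    # Model of the pack step: strip each 4-byte RGBA pixel back down to
    # the original number of channels (dropping the hard-coded alpha).
    out = bytearray()
    for i in range(0, len(unpacked), 4):
        out += unpacked[i:i + num_channels]
    return bytes(out)

# Round-trip of the unpack example: RGBA with A = 0 -> original RGB bytes.
assert pack_rgb(bytes([10, 20, 30, 0, 40, 50, 60, 0])) == bytes(
    [10, 20, 30, 40, 50, 60])
```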
```diff
@@ -326,11 +323,11 @@ void upsample_avx_bilinear_uint8(
   std::vector<at::Tensor> horiz_indices_weights, vert_indices_weights;
   unsigned int horiz_weights_precision, vert_weights_precision;
 
-  bool is_rgb_or_rgba = (num_channels == 3 || num_channels == 4) && input.is_contiguous(at::MemoryFormat::ChannelsLast);
+  bool needs_unpacking = (num_channels == 3 || num_channels == 4) && input.is_contiguous(at::MemoryFormat::ChannelsLast);
 
   if (need_horizontal) {
     int interp_dim = 3;
-    auto stride = (is_rgb_or_rgba) ? num_channels : 4;
+    auto stride = (needs_unpacking) ? num_channels : 4;
     std::tie(horiz_indices_weights, ksize_horiz, horiz_weights_precision) =
         F::compute_indices_int16_weights_aa(
             /*input_size=*/xin,
@@ -346,7 +343,7 @@ void upsample_avx_bilinear_uint8(
 
   if (need_vertical) {
     int interp_dim = 2;
-    auto stride = (is_rgb_or_rgba) ? num_channels * xout : 4 * xout;
+    auto stride = (needs_unpacking) ? num_channels * xout : 4 * xout;
     std::tie(vert_indices_weights, ksize_vert, vert_weights_precision) =
         F::compute_indices_int16_weights_aa(
             /*input_size=*/yin,
```
```diff
@@ -361,28 +358,28 @@ void upsample_avx_bilinear_uint8(
   }
 
   at::Tensor buffer_horiz, buffer_vert;
-  if (need_horizontal && !(is_rgb_or_rgba && !need_vertical)) {
-    auto c = (is_rgb_or_rgba) ? num_channels : 4;
+  if (need_horizontal && !(needs_unpacking && !need_vertical)) {
+    auto c = (needs_unpacking) ? num_channels : 4;
     buffer_horiz = at::empty({c, yin, xout}, input.options());
   }
-  if (need_vertical && !is_rgb_or_rgba) {
-    auto c = (is_rgb_or_rgba) ? num_channels : 4;
+  if (need_vertical && !needs_unpacking) {
+    auto c = (needs_unpacking) ? num_channels : 4;
     buffer_vert = at::empty({c, yout, xout}, input.options());
   }
 
-  // TODO: The unpack / pack operations create a copy of the original input and
-  // output tensor. There should be a way to avoid these copies by instead
+  // TODO: The unpack / pack operations create a
+  // copy of the original input and output tensor. There should be a way to avoid these copies by instead
   // modifying the low-level kernels. Or maybe at least avoid copying the entire
   // tensors and just copy part of them (line by line).
   for (const auto i : c10::irange(batch_size)) {
 
-    at::Tensor unpacked_input = (is_rgb_or_rgba) ? input[i] : unpack_rgb(input[i]);
+    at::Tensor unpacked_input = (needs_unpacking) ? input[i] : unpack_rgb(input[i]);
     at::Tensor unpacked_output;
 
     if (need_horizontal) {
-      at::Tensor unpacked_output_temp = (is_rgb_or_rgba && !need_vertical) ? output[i] : buffer_horiz;
+      at::Tensor unpacked_output_temp = (needs_unpacking && !need_vertical) ? output[i] : buffer_horiz;
 
-      if (is_rgb_or_rgba && num_channels == 3) {
+      if (needs_unpacking && num_channels == 3) {
         ImagingResampleHorizontal<3>(
             unpacked_output_temp,
             unpacked_input,
```
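The branching in this hunk decides, per image in the batch, whether the kernels can read/write the tensor directly (channels_last RGB/RGBA) or must round-trip through intermediate 4-channel buffers with an unpack/pack pair. A condensed pure-Python sketch of that dispatch (the `plan`/`direct` names are mine, not from the source; the step strings are illustrative):

```python
def plan(num_channels, channels_last, need_horizontal, need_vertical):
    # Mirrors the dispatch in upsample_avx_bilinear_uint8: channels_last
    # RGB/RGBA input is processed in place; anything else is unpacked to a
    # 4-channel buffer first and packed back to the output at the end.
    direct = num_channels in (3, 4) and channels_last
    steps = []
    if not direct:
        steps.append("unpack_rgb")
    if need_horizontal:
        target = "output" if direct and not need_vertical else "buffer_horiz"
        steps.append(f"horizontal pass -> {target}")
    if need_vertical:
        target = "output" if direct else "buffer_vert"
        steps.append(f"vertical pass -> {target}")
    if not direct:
        steps.append("pack_rgb")
    return steps

# channels_last RGB, horizontal-only: the kernel writes straight to the output.
assert plan(3, True, True, False) == ["horizontal pass -> output"]
# non-channels_last RGB: full unpack / buffer / pack round trip.
assert plan(3, False, True, True) == [
    "unpack_rgb", "horizontal pass -> buffer_horiz",
    "vertical pass -> buffer_vert", "pack_rgb"]
```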
```diff
@@ -400,7 +397,7 @@ void upsample_avx_bilinear_uint8(
       unpacked_output = unpacked_input = unpacked_output_temp;
     }
     if (need_vertical) {
-      unpacked_output = (is_rgb_or_rgba) ? output[i] : buffer_vert;
+      unpacked_output = (needs_unpacking) ? output[i] : buffer_vert;
 
       ImagingResampleVertical(
           unpacked_output,
@@ -413,7 +410,7 @@ void upsample_avx_bilinear_uint8(
 
     TORCH_INTERNAL_ASSERT(unpacked_output.defined());
 
-    if (!is_rgb_or_rgba) {
+    if (!needs_unpacking) {
       pack_rgb(unpacked_output, output[i]);
     }
   }
```
```diff
@@ -441,7 +438,8 @@ void ImagingResampleHorizontalConvolution8u4x(
   // Interpolation horizontal pass processing together 4 vertical lines.
   // - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA
   //   we can encode 4 values as a single uint32 value.
-  // - We split the size of weight vector for a given output index as a sum: K = n * 4 + m * 2 + k.
+  // - We split the size of weight vector for a given output index as a sum:
+  //   ids_size = num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1.
   // - We load and process 4 weights values in a loop ("block 4") then we process 2 weights values
   //   in another loop ("block 2") and finally we process 1 weights value in the final loop ("block 1").
 
@@ -740,7 +738,8 @@ void ImagingResampleHorizontalConvolution8u(
   // Interpolation horizontal pass processing only one vertical line.
   // - Input data format is RGBA or RGB with R,G,B,A being uint8. In case of RGBA
   //   we can encode 4 values as a single uint32 value.
-  // - We split the size of weight vector for a given output index as a sum: K = n * 8 + m * 4 + k * 2 + l.
+  // - We split the size of weight vector for a given output index as a sum:
+  //   ids_size = num_blocks_8 * 8 + num_blocks_4 * 4 + num_blocks_2 * 2 + num_blocks_1
   // - We load and process 8 weights values in a loop ("block 8") then 4 weights and 2 weights values in
   //   in another loops ("block 4" and "block 2") and finally we process 1 weight value in the final loop ("block 1").
```