
Commit eebed5a

Update base for Update on "Improved perfs for vectorized bilinear interpolate cpu uint8 RGB-case (channels last)"
## Description

- Based on #96651
- Improved performance of vectorized **bilinear** interpolation for the uint8 RGB case, **channels last**
  - Unified the RGB and RGBA processing code so that RGB input is no longer copied into RGBA
- Performance is now closer to Pillow-SIMD (labeled as `Pillow (9.0.0.post1)` in the results)
- RGBA-case performance is unchanged after the refactoring (see the Source link below)
- Fixed memory pointer alignment and added more comments (addressing reviews from #96651)

## Results

- `Pillow (9.0.0.post1)` == Pillow-SIMD

```
[--------------------------------------------------------------------------------- Resize ---------------------------------------------------------------------------------]
                                                                               |  Pillow (9.0.0.post1)  |  torch (2.1.0a0+gitc005105) PR  |  torch (2.1.0a0+git5309c44) nightly  |  Speed-up: PR vs nightly
1 threads: ------------------------------------------------------------------------------------------------------------------------------------------------------------------
      3 torch.uint8 channels_last bilinear (256, 256) -> (32, 32) aa=True      |   38.670 (+-0.445)     |    57.366 (+-0.799)             |   132.147 (+-1.236)                  |  2.304 (+-0.000)
      3 torch.uint8 channels_last bilinear (256, 256) -> (32, 32) aa=False     |                        |    37.825 (+-0.417)             |   111.789 (+-1.175)                  |  2.955 (+-0.000)
      3 torch.uint8 channels_last bilinear (256, 256) -> (224, 224) aa=True    |  127.898 (+-1.335)     |   153.081 (+-2.346)             |   302.518 (+-2.632)                  |  1.976 (+-0.000)
      3 torch.uint8 channels_last bilinear (256, 256) -> (224, 224) aa=False   |                        |   141.695 (+-1.415)             |   286.663 (+-2.494)                  |  2.023 (+-0.000)
      3 torch.uint8 channels_last bilinear (256, 256) -> (320, 320) aa=True    |  179.735 (+-2.054)     |   210.613 (+-3.116)             |   439.375 (+-4.014)                  |  2.086 (+-0.000)
      3 torch.uint8 channels_last bilinear (256, 256) -> (320, 320) aa=False   |                        |   207.601 (+-1.639)             |   438.537 (+-4.143)                  |  2.112 (+-0.000)
      3 torch.uint8 channels_last bilinear (520, 520) -> (32, 32) aa=True      |  112.679 (+-1.321)     |   130.863 (+-1.987)             |   446.804 (+-3.283)                  |  3.414 (+-0.000)
      3 torch.uint8 channels_last bilinear (520, 520) -> (32, 32) aa=False     |                        |    57.968 (+-0.270)             |   374.244 (+-13.598)                 |  6.456 (+-0.000)
      3 torch.uint8 channels_last bilinear (520, 520) -> (224, 224) aa=True    |  282.398 (+-3.485)     |   322.986 (+-1.947)             |   720.197 (+-3.467)                  |  2.230 (+-0.000)
      3 torch.uint8 channels_last bilinear (520, 520) -> (224, 224) aa=False   |                        |   231.625 (+-2.006)             |   592.834 (+-3.903)                  |  2.559 (+-0.000)
      3 torch.uint8 channels_last bilinear (712, 712) -> (32, 32) aa=True      |  185.711 (+-1.666)     |   201.069 (+-2.182)             |   787.868 (+-3.648)                  |  3.918 (+-0.000)
      3 torch.uint8 channels_last bilinear (712, 712) -> (32, 32) aa=False     |                        |    75.975 (+-0.696)             |   651.016 (+-3.926)                  |  8.569 (+-0.000)
      3 torch.uint8 channels_last bilinear (712, 712) -> (224, 224) aa=True    |  410.236 (+-6.021)     |   451.486 (+-3.939)             |  1123.923 (+-14.988)                 |  2.489 (+-0.000)
      3 torch.uint8 channels_last bilinear (712, 712) -> (224, 224) aa=False   |                        |   299.597 (+-1.887)             |   915.347 (+-4.486)                  |  3.055 (+-0.000)

      # More test-cases from #90771
      3 torch.uint8 channels_last bilinear (64, 64) -> (224, 224) aa=True      |   60.751 (+-0.285)     |    78.538 (+-1.282)             |   170.465 (+-1.830)                  |  2.170 (+-0.000)
      3 torch.uint8 channels_last bilinear (224, 224) -> (270, 268) aa=True    |  133.619 (+-2.035)     |   159.614 (+-1.587)             |   330.971 (+-3.249)                  |  2.074 (+-0.000)
      3 torch.uint8 channels_last bilinear (256, 256) -> (1024, 1024) aa=True  |  950.243 (+-10.641)    |   891.369 (+-17.946)            |  2805.510 (+-25.503)                 |  3.147 (+-0.000)
      3 torch.uint8 channels_last bilinear (224, 224) -> (64, 64) aa=True      |   52.771 (+-0.961)     |    72.253 (+-1.020)             |   135.933 (+-1.625)                  |  1.881 (+-0.000)
      3 torch.uint8 channels_last bilinear (270, 268) -> (224, 224) aa=True    |  139.107 (+-2.143)     |   165.844 (+-2.177)             |   321.112 (+-2.904)                  |  1.936 (+-0.000)
      3 torch.uint8 channels_last bilinear (1024, 1024) -> (256, 256) aa=True  |  691.470 (+-9.566)     |   764.942 (+-11.192)            |  2050.880 (+-22.188)                 |  2.681 (+-0.000)
      3 torch.uint8 channels_last bilinear (64, 64) -> (224, 224) aa=False     |                        |    77.375 (+-1.345)             |   169.646 (+-1.640)                  |  2.193 (+-0.000)
      3 torch.uint8 channels_last bilinear (224, 224) -> (270, 268) aa=False   |                        |   159.115 (+-3.935)             |   329.754 (+-2.590)                  |  2.072 (+-0.000)
      3 torch.uint8 channels_last bilinear (256, 256) -> (1024, 1024) aa=False |                        |   877.248 (+-5.736)             |  2815.870 (+-22.589)                 |  3.210 (+-0.000)
      3 torch.uint8 channels_last bilinear (224, 224) -> (64, 64) aa=False     |                        |    53.120 (+-0.316)             |   112.024 (+-1.225)                  |  2.109 (+-0.000)
      3 torch.uint8 channels_last bilinear (270, 268) -> (224, 224) aa=False   |                        |   147.330 (+-1.871)             |   299.152 (+-3.353)                  |  2.030 (+-0.000)
      3 torch.uint8 channels_last bilinear (1024, 1024) -> (256, 256) aa=False |                        |   472.182 (+-10.785)            |  1698.601 (+-16.785)                 |  3.597 (+-0.000)
```

Note: for the other benchmarked cases (see Source below) the speed-up is roughly 1.0 +/- 0.1, which may be attributed to noisy measurements.

[Source](https://gist.github.com/vfdev-5/1c0778904a07ce40401306548b9525e8#file-20230320-160044-pr_vs_nightly-speedup-md)

## Context

- #90771

cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

[ghstack-poisoned]
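For orientation: per output pixel and channel, the horizontal pass these kernels implement is a fixed-point weighted sum over a small window of input pixels. Below is a minimal scalar sketch of that computation, not the PR's code: `resample_row_rgb` and `clip8` are hypothetical names, while `ids_min`, `ids_size`, and `coefs_precision` mirror identifiers that appear in the patch.

```cpp
#include <cstdint>
#include <vector>

// Drop the fixed-point fraction bits and clamp to [0, 255].
static inline uint8_t clip8(int v, unsigned precision) {
  v >>= precision;
  return v < 0 ? 0 : (v > 255 ? 255 : static_cast<uint8_t>(v));
}

// Hypothetical scalar reference for one output row, RGB, channels last
// (memory layout ...r g b r g b...). For each output x, a window of
// ids_size[out_x] input pixels starting at ids_min[out_x] is accumulated
// against int16 fixed-point coefficients.
void resample_row_rgb(
    const uint8_t* in, uint8_t* out, int out_xsize,
    const std::vector<int>& ids_min,                 // window start per out_x
    const std::vector<int>& ids_size,                // window length per out_x
    const std::vector<std::vector<int16_t>>& coefs,  // weights per out_x
    unsigned coefs_precision) {
  for (int out_x = 0; out_x < out_xsize; out_x++) {
    for (int c = 0; c < 3; c++) {
      // Rounding term; same role as the kernels'
      // initial = _mm256_set1_epi32(1 << (coefs_precision - 1)).
      int ss = 1 << (coefs_precision - 1);
      for (int k = 0; k < ids_size[out_x]; k++) {
        ss += in[(ids_min[out_x] + k) * 3 + c] * coefs[out_x][k];
      }
      out[out_x * 3 + c] = clip8(ss, coefs_precision);
    }
  }
}
```

The AVX2 kernels in the diff below compute the same sums for several output pixels at once; the shuffle masks they build are what turn interleaved uint8 channel data into 16-bit lanes that can be multiply-accumulated against packed coefficients.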
1 parent 8ce7530 · commit eebed5a

1 file changed (+21 -43 lines)

aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h

Lines changed: 21 additions & 43 deletions
```diff
@@ -394,19 +394,15 @@ void ImagingResampleHorizontalConvolution8u4x(
   // [ ... r3 g3 b3 a3 r4 g4 b4 a4 | ... R3 G3 B3 A3 R4 G4 B4 A4 ] ->
   // [r3 0 r4 0 g3 0 g4 0 b3 0 b4 0 a3 0 a4 0 | R3 0 R4 0 G3 0 G4 0 B3 0 B4 0 A3 0 A4 0]
 
-  const __m256i masks_low_high_c4[2] = {
-    _mm256_set_epi8(
+  const auto mask_low_c4 = _mm256_set_epi8(
       -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
-      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0
-    ),
-    _mm256_set_epi8(
+      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
+  const auto mask_high_c4 = _mm256_set_epi8(
       -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
-      -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8
-    )
-  };
+      -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8);
 
-  const auto mask_low = masks_low_high_c4[0];
-  const auto mask_high = masks_low_high_c4[1];
+  const auto mask_low = mask_low_c4;
+  const auto mask_high = mask_high_c4;
 
   const auto zero = _mm256_setzero_si256();
   const auto initial = _mm256_set1_epi32(1 << (coefs_precision - 1));
@@ -568,44 +564,26 @@ void ImagingResampleHorizontalConvolution8u(
       7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4, 7, 6, 5, 4,
       3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0, 3, 2, 1, 0);
 
-  const __m256i masks_low_high_c4[2] = {
-    _mm256_set_epi8(
+  const auto mask_low_c4 = _mm256_set_epi8(
       -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
-      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0
-    ),
-    _mm256_set_epi8(
+      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
+  const auto mask_high_c4 = _mm256_set_epi8(
       -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
-      -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8
-    )
-  };
-  const __m256i masks_lh_c3_c4[2] = {
-    _mm256_set_epi8(
-      -1, -1, -1, -1, -1, 11, -1, 8, -1, 10, -1, 7, -1, 9, -1, 6,
-      -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0
-    ),
-    _mm256_set_epi8(
+      -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8);
+  const auto mask_hl_c4 = _mm256_set_epi8(
       -1, 15, -1, 11, -1, 14, -1, 10, -1, 13, -1, 9, -1, 12, -1, 8,
-      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0
-    )
-  };
-
-  const __m128i masks_low128_c3_c4[2] = {
-    _mm_set_epi8(
-      -1, -1, -1, -1, -1, 5, -1, 2, -1, 4, -1, 1, -1, 3, -1, 0
-    ),
-    _mm_set_epi8(
-      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0
-    )
-  };
-
-  const auto mask_low = masks_low_high_c4[0];
-  const auto mask_high = masks_low_high_c4[1];
-  const auto mask_hl = masks_lh_c3_c4[1];
-  const auto mask_low128 = masks_low128_c3_c4[1];
+      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
+  const auto mask_low128_c4 = _mm_set_epi8(
+      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);
+
+  const auto mask_low = mask_low_c4;
+  const auto mask_high = mask_high_c4;
+  const auto mask_hl = mask_hl_c4;
+  const auto mask_low128 = mask_low128_c4;
 
   // out_xsize = output width, out_x = output x index
-  // ids_size = interpolation size
-  // ids_min = input x start index corresponding to output x index (out_x)
+  // ids_min is the input offset index corresponding to out_x
+  // ids_size is the interpolation size for out_x
 
   const auto zero = _mm_setzero_si128();
 
```
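To make the mask constants above concrete: in `_mm256_shuffle_epi8`, a selector byte with its high bit set (written `-1`, i.e. 0xFF, in these tables) zeroes the corresponding output byte, so one shuffle both de-interleaves channel bytes and zero-extends them into 16-bit lanes. A standalone sketch of `mask_low_c4` in action (not part of the patch; assumes an AVX2 machine, compile with `-mavx2`):

```cpp
#include <immintrin.h>
#include <cstdint>
#include <cstdio>

int main() {
  // 8 RGBA pixels, channels last: byte i holds value i, so pixel 0 is
  // (r=0, g=1, b=2, a=3), pixel 1 is (4, 5, 6, 7), and so on.
  alignas(32) uint8_t px[32];
  for (int i = 0; i < 32; i++) px[i] = static_cast<uint8_t>(i);
  const __m256i pixels =
      _mm256_load_si256(reinterpret_cast<const __m256i*>(px));

  // Same constant as mask_low_c4 above. _mm256_shuffle_epi8 operates
  // within each 128-bit lane; -1 selectors produce zero bytes.
  const __m256i mask_low_c4 = _mm256_set_epi8(
      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0,
      -1, 7, -1, 3, -1, 6, -1, 2, -1, 5, -1, 1, -1, 4, -1, 0);

  // Per lane, the output bytes are [r0 0 r1 0 g0 0 g1 0 b0 0 b1 0 a0 0 a1 0],
  // i.e. the uint16 values [r0 r1 g0 g1 b0 b1 a0 a1]: two pixels
  // de-interleaved and zero-extended in a single instruction.
  const __m256i unpacked = _mm256_shuffle_epi8(pixels, mask_low_c4);

  alignas(32) uint16_t out[16];
  _mm256_store_si256(reinterpret_cast<__m256i*>(out), unpacked);
  for (int i = 0; i < 16; i++) printf("%d ", static_cast<int>(out[i]));
  printf("\n");  // 0 4 1 5 2 6 3 7 16 20 17 21 18 22 19 23
  return 0;
}
```

Placing two samples of the same channel in adjacent 16-bit lanes is what lets the kernels use madd-style multiply-accumulate (e.g. `_mm256_madd_epi16`) against coefficient pairs to produce per-channel 32-bit partial sums.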