Conversation

NicolasHug (Member) commented Oct 27, 2022

This is heavily adapted / ported from PILLOW-SIMD.

This PR adds support for uint8 images in the following case:

  • mode=bilinear, antialias=True -- support for other modes is possible
  • shape=(1, 3, H, W) -- could be extended to batch_size > 1; not sure about other channel counts yet
  • layout=channels_last -- we could easily support contiguous inputs as well, since we have to copy (pack / unpack) the input anyway
  • device=CPU

This may sound restrictive, but it's not: this is exactly the setting in which torchvision's Resize() is used for training jobs.

This is still WIP with lots of TODOs, but it seems to be working decently. On the inputs I tried, comparing against torchvision's Resize() (which first converts uint8 to float, runs interpolate(), and converts back to uint8), I'm getting a ~3-5X speedup. It also seems correct so far: the absolute difference between the outputs is never > 1, and only ~15% of the pixel values differ (by exactly 1).
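For reference, here is a rough sketch (in ATen C++; the shape, the output size, and the use of the internal _upsample_bilinear2d_aa op are assumptions for illustration) of the float round-trip described above, which the new uint8 path is benchmarked against:

```cpp
#include <ATen/ATen.h>

// uint8 -> float -> antialiased bilinear resize -> uint8: roughly what
// torchvision's Resize() does today for uint8 inputs.
at::Tensor resize_via_float(const at::Tensor& img_u8) {
  auto out_f = at::_upsample_bilinear2d_aa(
      img_u8.to(at::kFloat), /*output_size=*/{256, 256},
      /*align_corners=*/false);
  return out_f.round_().clamp_(0, 255).to(at::kByte);
}
```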

This addresses pytorch/vision#2289

@vfdev-5 @mingfeima @fmassa I'd love your initial thoughts on this!

cc @VitalyFedyunin @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

NicolasHug changed the title from "WIP Add uint8 support for interpolate on channels_last (1, 3, H, W) CPU images" to "WIP Add uint8 support for interpolate on channels_last (1, 3, H, W) CPU images, mode=bilinear, antialias=True" on Oct 27, 2022
vadimkantorov (Contributor) commented Oct 27, 2022

Is it vectorizing the reads/writes? If so, it may also be nice to support 4-channel inputs (RGBA?), or any channel count divisible by 4 or 8 -- vectorization is probably even simpler in those cases.

Also, interpolate support for 1-channel uint8/int16/uint32 inputs would be useful (nearest mode for label maps, and interpolation of audio signals).

mingfeima (Collaborator) commented

Is this a migration of PIL-SIMD's kernel? Shall we parallelize it at the same time? It shouldn't be too much additional work ~

NicolasHug (Member, PR author) commented

Thanks for the feedback @vadimkantorov. Yes, the vectorized part is roughly output[i] = sum_j (w_j * input[j]); the pre-computation of the weights and of the index mapping isn't vectorized. Regarding the next features to support, we can certainly add these to our backlog, but I'll have to gauge interest internally to decide what to prioritize.
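For illustration, a scalar sketch of that fixed-point weighted sum (hypothetical names; the actual kernel operates on packed 32-bit pixels in AVX2 registers, with scaled int16 weights as in the PIL-SIMD kernels):

```cpp
#include <cstdint>

// One output value as output[i] = sum_j (w_j * input[j]), with int16
// weights scaled by 2^coefs_precision.
uint8_t weighted_sum(const uint8_t* input, const int16_t* w,
                     int xmin, int xmax, int coefs_precision) {
  int32_t acc = 1 << (coefs_precision - 1);  // rounding bias
  for (int x = xmin; x < xmax; x++) {
    acc += int32_t(input[x]) * w[x - xmin];
  }
  acc >>= coefs_precision;
  // clamp to the uint8 range
  return uint8_t(acc < 0 ? 0 : (acc > 255 ? 255 : acc));
}
```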

@mingfeima yes, this is a direct port from PIL-SIMD. Indeed, I think we should be able to parallelize over the batch size. Did you have something else in mind?

vadimkantorov (Contributor) commented

Maybe I wasn't clear; I meant: are the reads and writes themselves vectorized? Are all three R, G, B channels read and written in one go? For 3-channel uint8, a pixel may just fit in one uint32, but for 3-channel float32, SSE registers could be used. The memory accesses would probably not be well aligned, since memory is read in triplets (rather than quadruplets), but maybe that doesn't hurt much on modern processors.

NicolasHug (Member, PR author) commented

Hm, I'm honestly not sure what you mean by read-ins and write-outs. If you mean the packing/unpacking of the input/output, that part isn't vectorized. The vectorized part is the writing of the unpacked output from the unpacked input.

> Are all three R, G, B channels read and written in one go?

Yes.

vadimkantorov (Contributor) commented

I guess the packing/unpacking part wasn't very clear to me. Is it first copying the memory into an unpacked format and then processing it in a vectorized way? My question is whether this first unpacking is needed, or whether it could be fused with the processing (but this would incur some useless computation for the non-existent 4th channel).

NicolasHug (Member, PR author) commented Oct 28, 2022

> Is it first copying the memory into an unpacked format and then processing it in a vectorized way?

Yes. The input tensor arrives as R G B R G B R G B ..., where each letter is a uint8. For the vectorized code to run, we first have to unpack each RGB triplet into 32 bits (the last 8 bits are set to 255), so the data becomes R G B P R G B P ..., where P is the 255 padding (the value 255 is arbitrary; I guess it's just there for consistency with RGBA images). The output is written in that same unpacked format and needs to be re-packed to get a proper tensor.
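A minimal sketch of that unpack step (illustrative only, not the exact code from this PR):

```cpp
#include <cstdint>

// R G B R G B ... (3 bytes per pixel) -> R G B P R G B P ... (4 bytes per
// pixel), so each pixel can then be read and written as a single uint32.
void unpack_rgb_to_rgbp(uint8_t* unpacked, const uint8_t* packed,
                        int num_pixels) {
  for (int i = 0; i < num_pixels; i++) {
    unpacked[4 * i + 0] = packed[3 * i + 0];  // R
    unpacked[4 * i + 1] = packed[3 * i + 1];  // G
    unpacked[4 * i + 2] = packed[3 * i + 2];  // B
    unpacked[4 * i + 3] = 255;                // P: arbitrary padding value
  }
}
```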

This does involve extra copies, and there might be opportunities to improve this in the future. But even with the extra copies, this is still way faster than torchvision's Resize(), and better than no uint8 support at all, so it's definitely worth it.

vadimkantorov (Contributor) commented Oct 28, 2022

Then yes, maybe in the future the copies can be avoided and fused directly with the reading/writing (probably worth adding a TODO or an explanation about this in the code...)

vadimkantorov (Contributor) commented

Hope that soon Pillow won't be needed for simple pipelines :)


Review thread on:

```cpp
separable_upsample_generic_Nd_kernel_impl<2, scale_t, HelperInterpLinear>(
    output, input, align_corners, {scales_h, scales_w});
if (input.dtype() == at::kByte) {
```
Contributor:

you might want to use something like _use_vectorized_kernel_cond here (a few functions in the file seem to be using that condition to decide whether or not to vectorize)

Contributor (reply):

synced offline -- currently the focus is to only support this for uint8. In the future, we can extend this to other data types.

Review thread on:

```cpp
return unpacked_output_p;
}

void beepidiboop(const Tensor& input, const Tensor& output) {
```
anjali411 (Contributor) commented Oct 31, 2022:

I think you should add a template for scalar type here

Review thread on:

```cpp
// - There's a segfault when input_shape == output_shape
// - This could be extended to other filters, not just bilinear
// - License?
beepidiboop(input, output);
```
Contributor:

call this with the appropriate AT_DISPATCH_{} macro once the function is templated
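For example, a hypothetical sketch of that dispatch (assuming the function gains a scalar-type template parameter as suggested above; the name string is arbitrary):

```cpp
// AT_DISPATCH_ALL_TYPES selects on the runtime dtype and defines scalar_t
// inside the lambda for each supported type (including uint8 / at::kByte).
AT_DISPATCH_ALL_TYPES(input.scalar_type(), "upsample_bilinear2d_aa", [&] {
  beepidiboop<scalar_t>(input, output);
});
```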

Review thread on:

```cpp
UINT32 *lineOut, UINT32 *imIn,
int xmin, int xmax, INT16 *k, int coefs_precision, int xin)
{
#ifdef CPU_CAPABILITY_AVX2
```
Contributor:

add a comment sharing the link from where this was borrowed

anjali411 requested a review from ngimel on October 31, 2022

Review thread on:

```cpp
return 0.0;
}

void unpack_rgb(uint8_t * unpacked, const uint8_t * packed, int num_pixels)
```
Member:

oh ok, this is a trick to perform vector reads / writes in the kernel with uint32_t more easily. This means that we could also "easily" support num_channels <= 4 by just changing this function maybe? It wouldn't be the most efficient implementation, but might be worth checking if it would be faster for num_channels==1 compared to what we already have
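A hypothetical generalization along those lines (illustrative only):

```cpp
#include <cstdint>

// Unpack any num_channels <= 4 into 4-byte pixels, padding the missing
// channels, so the same uint32-based kernel could run unchanged.
void unpack_pixels(uint8_t* unpacked, const uint8_t* packed,
                   int num_pixels, int num_channels) {
  for (int i = 0; i < num_pixels; i++) {
    for (int c = 0; c < 4; c++) {
      unpacked[4 * i + c] =
          (c < num_channels) ? packed[num_channels * i + c] : 255;
    }
  }
}
```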


Review thread on:

```cpp
/* coefficient buffer */
/* malloc check ok, overflow checked above */
kk = (double *)malloc(outSize * ksize * sizeof(double));
```
Member:

Ideally, all mallocs should use PyTorch's allocator, so that you don't need to handle the frees yourself.
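For instance, a sketch of that suggestion (assuming the buffer can simply be owned by a tensor):

```cpp
// Let a Tensor own the coefficient buffer instead of a raw malloc; the
// memory is freed automatically when the tensor goes out of scope.
auto kk_tensor = at::empty({outSize * ksize}, at::kDouble);
double* kk = kk_tensor.data_ptr<double>();
```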

github-actions bot added the "module: cpu" label (CPU specific problem, e.g., perf, algorithm) on Nov 3, 2022
NicolasHug (Member, PR author) commented

Closing this in favor of #90771, which is more complete.

NicolasHug closed this on Dec 13, 2022