ssnl commented Mar 21, 2018

Add align_corners option to upsampling module & functional when using linearly interpolating modes:

When align_corners=True, it uses the original upsampling scheme, which gives visually better results
but doesn't properly align input and output pixels, and thus causes the output to vary depending on the input size.
This PR adds the align_corners option and changes the default behavior to align_corners=False, with
a warning if the option is not specified when using nn.Upsample or nn.functional.upsample, to make
users aware of the change.
Adds tests in test_nn.py for spatial invariance when align_corners=False, and the usual module tests for
align_corners=False.

The ratio is basically computed as:

ratio = align_corners ? (input_size - 1) / (output_size - 1) : input_size / output_size

And src_idx is:

if align_corners:
  src_idx = dst_idx * ratio
else:
  src_idx = (dst_idx + 0.5) * ratio - 0.5

The 0.5 offsets map each index to the location of its pixel center.
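
To make the two mappings concrete, here is a minimal pure-Python sketch of the ratio and src_idx formulas above. The function name source_index and the handling of output_size == 1 are assumptions made for illustration; this is not the C/CUDA code from the PR.

```python
def source_index(dst_idx, input_size, output_size, align_corners):
    """Map an output index to a (possibly fractional) input coordinate."""
    if align_corners:
        # Corner pixels of input and output coincide.
        # Assumption: with a single output pixel, map it to input index 0.
        if output_size > 1:
            ratio = (input_size - 1) / (output_size - 1)
        else:
            ratio = 0.0
        return dst_idx * ratio
    # Pixel-center convention: shift to the center, scale, shift back.
    ratio = input_size / output_size
    return (dst_idx + 0.5) * ratio - 0.5


# Upsampling 2 input pixels to 4 output pixels under both conventions.
aligned = [source_index(d, 2, 4, True) for d in range(4)]
centered = [source_index(d, 2, 4, False) for d in range(4)]
print(aligned)
print(centered)
```

With align_corners=True the first and last output indices land exactly on the first and last input indices. With align_corners=False the outermost output pixels map slightly outside the input range (to -0.25 and 1.25 in this example), so an implementation has to decide how to treat those coordinates.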

This also makes the default upsampling behavior consistent with other DL frameworks like tf.

This solves the issue raised in #5511

cc @Dorimer


ssnl commented Mar 22, 2018

@pytorchbot retest this please

@ssnl ssnl force-pushed the bilinear_upsampling_fix branch 3 times, most recently from 534a7f6 to 981e339 Compare March 22, 2018 22:25

ssnl commented Mar 22, 2018

@pytorchbot retest this please


ssnl commented Mar 23, 2018

@Dorimer Let me know if the math formula in this PR looks okay to you.

@ezyang ezyang added the module: bc-breaking Related to a BC-breaking change label Mar 23, 2018

This comment was marked as off-topic.


ezyang commented Mar 23, 2018

Please come to an agreement about the formulas before merging, thanks!


vguzov commented Mar 24, 2018

@ssnl,

Let me know if the math formula in this PR looks okay to you.

Yes, I verified the formula and it looks good. However, I discovered another problem with the new behaviour: it can sometimes produce negative values: https://gist.github.com/Dorimer/899faff586f27070ce55f1cc78c69ce1
This is easy to fix, though. Just replace

return scale * (dst_index + 0.5) - 0.5;

with

float src_index = scale * (dst_index + 0.5) - 0.5;
return (src_index < 0) ? 0.f : src_index;

The CUDA code needs the same change.
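
A small pure-Python sketch of why the clamp matters, under two assumptions: the integer part of the source index is taken by truncation toward zero (as a C cast does), and the upper neighbour index is clamped to the last element. The helper name interpolate_1d is hypothetical and this is not the PR's actual kernel.

```python
def interpolate_1d(values, output_size, clamp):
    """Linearly upsample a 1-D list with the align_corners=False mapping."""
    n = len(values)
    scale = n / output_size
    out = []
    for dst_index in range(output_size):
        src_index = scale * (dst_index + 0.5) - 0.5
        if clamp and src_index < 0:
            src_index = 0.0  # the proposed fix
        i0 = int(src_index)      # truncates toward zero, like a C cast
        i1 = min(i0 + 1, n - 1)  # stay inside the input
        lam = src_index - i0     # negative when src_index < 0 and unclamped
        out.append((1 - lam) * values[i0] + lam * values[i1])
    return out


print(interpolate_1d([0.0, 1.0], 4, clamp=False))
print(interpolate_1d([0.0, 1.0], 4, clamp=True))
```

Without the clamp, the first output pixel has a source index of -0.25, so the interpolation weight is negative and the result extrapolates below the input range even though every input value is non-negative. With the clamp, the first output equals the first input value.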

vguzov and others added 4 commits March 23, 2018 23:19
… linearly interpolating modes

@ssnl ssnl force-pushed the bilinear_upsampling_fix branch from fdafbd6 to c249518 Compare March 24, 2018 03:19
@ezyang ezyang merged commit 5d77709 into pytorch:master Mar 24, 2018
@ssnl ssnl deleted the bilinear_upsampling_fix branch March 24, 2018 16:57
@shelhamer

@Dorimer @ssnl for what it's worth, I have been handling bilinear interpolation with ConvTranspose2d, analogously to how it is done with DeconvolutionLayer in Caffe. See here for reference code: https://gist.github.com/shelhamer/2d07f3afea5d9628f530af26e3846858.
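
For readers unfamiliar with that approach, the sketch below builds the fixed weights commonly used for bilinear upsampling with a transposed convolution. It is one well-known construction written in pure Python, is not taken from the linked gist, and may differ from it in detail. The function names are hypothetical, and the usual pairing of kernel size 2s - s % 2 with stride s for an upsampling factor s is stated here as an assumption.

```python
def bilinear_kernel_1d(size):
    """1-D triangular (linear interpolation) weights for a kernel of `size` taps."""
    factor = (size + 1) // 2
    # The peak sits on a tap for odd sizes and between two taps for even sizes.
    center = factor - 1 if size % 2 == 1 else factor - 0.5
    return [1 - abs(i - center) / factor for i in range(size)]


def bilinear_kernel_2d(size):
    """2-D bilinear weights as the outer product of the 1-D weights."""
    k = bilinear_kernel_1d(size)
    return [[a * b for b in k] for a in k]


# Kernel for 2x upsampling: size = 2 * 2 - 2 % 2 = 4 (assumed pairing).
print(bilinear_kernel_1d(4))
```

Such a kernel would be copied into the weight of each channel's transposed convolution and usually kept fixed rather than learned.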

sighingnow added a commit to sighingnow/pytorch that referenced this pull request Mar 25, 2018
* upstream/master: (663 commits)
  Fix "command not found" error in perf test (pytorch#5982)
  add pip mkl-devel to the error message when mkl is found but mkl headers are not (pytorch#5984)
  Support batch LowerCholeskyTransform (pytorch#5980)
  Linearly interpolating upsampling fix (pytorch#5927)
  Store perf numbers in S3 (pytorch#5951)
  Modidy setup docs for Windows (pytorch#5981)
  Group Normalization (pytorch#5968)
  [distributions] Implement Power transform (pytorch#5976)
  Disable TestBottleneck test_cuda on Windows (pytorch#5977)
  Fix crash when cat-ing empty cuda tensors (pytorch#5971)
  Update no_unions flag for nanopb gen and update ONNX proto files (pytorch#5972)
  Expose gradients w.r.t. input & weight for conv1d, conv2d, conv3d in Python (pytorch#5408)
  Fixed non-determinate preprocessing on DataLoader (pytorch#4640)
  add AVX2 implementation for sigmoid function (pytorch#5010)
  Implement torch.util.bottleneck (pytorch#5216)
  Remove pragma once from cpp file (pytorch#5965)
  fix mvn docs (pytorch#5967)
  Fix incorrect rendering of Tensor.index_*_ doc examples. (pytorch#5969)
  Implement range for loop in script (pytorch#5827)
  Add windows doc (pytorch#5859)
  ...

# Conflicts:
#	aten/src/TH/generic/THTensorMath.c
#	torch/_tensor_docs.py
#	torch/csrc/generic/methods/TensorCompare.cwrap
@zou3519 zou3519 mentioned this pull request Aug 28, 2018