Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions aten/src/ATen/native/ReflectionPad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,11 @@ void reflection_pad1d_out_template(
int64_t dim_plane = 0;
int64_t dim_w = 1;
int64_t nbatch = 1;

TORCH_CHECK(input_.numel() > 0 &&
(input_.ndimension() == 2 || input_.ndimension() == 3), "non-empty 2D "
"or 3D (batch mode) tensor expected for input, but got: ", input_);
// allow dim=0 only in the batch dimension.
TORCH_CHECK(
(input_.ndimension() == 2 && input_.size(1) != 0) ||
(input_.ndimension() == 3 && input_.size(1) != 0 && input_.size(2) != 0),
"2D or 3D (batch mode) tensor expected for input, but got: ", input_);

if (input_.ndimension() == 3) {
nbatch = input_.size(0);
Expand Down Expand Up @@ -300,9 +301,11 @@ void reflection_pad2d_out_template(
int dim_slices = 0;
int64_t nbatch = 1;

TORCH_CHECK(input_.numel() > 0 &&
(input_.ndimension() == 3 || input_.ndimension() == 4), "non-empty 3D or "
"4D (batch mode) tensor expected for input, but got: ", input_);
bool valid_dims = input_.size(1) != 0 && input_.size(2) != 0;
TORCH_CHECK(
(input_.ndimension() == 3 && valid_dims) ||
(input_.ndimension() == 4 && valid_dims && input_.size(3) != 0),
"3D or 4D (batch mode) tensor expected for input, but got: ", input_);

if (input_.ndimension() == 4) {
nbatch = input_.size(0);
Expand Down
31 changes: 25 additions & 6 deletions aten/src/ATen/native/cuda/ReflectionPad.cu
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,10 @@ void reflection_pad1d_out_template(
int64_t dim_w = 1;
int64_t nbatch = 1;

TORCH_CHECK(input_.numel() > 0 &&
(input_.ndimension() == 2 || input_.ndimension() == 3), "non-empty 2D "
"or 3D (batch mode) tensor expected for input, but got: ", input_);
TORCH_CHECK(
(input_.ndimension() == 2 && input_.size(1) != 0) ||
(input_.ndimension() == 3 && input_.size(1) != 0 && input_.size(2) != 0),
"2D or 3D (batch mode) tensor expected for input, but got: ", input_);

if (input_.ndimension() == 3) {
nbatch = input_.size(0);
Expand All @@ -184,6 +185,9 @@ void reflection_pad1d_out_template(
} else {
output.resize_({nbatch, nplane, output_w});
}
if (output.numel() == 0) {
return;
}

dim3 block_size(output_w > 256 ? 256 : output_w);
dim3 grid_size((int) ::ceil(output_w / 256.0), nplane, nbatch);
Expand All @@ -206,6 +210,10 @@ void reflection_pad1d_backward_out_template(
Tensor & grad_input, const Tensor & grad_output_,
const Tensor & input, IntArrayRef padding) {

if (grad_input.numel() == 0) {
return;
}

TORCH_CHECK(canUse32BitIndexMath(input),
"input tensor must fit into 32-bit index math");

Expand Down Expand Up @@ -252,6 +260,7 @@ void reflection_pad1d_backward_out_template(

void reflection_pad2d_out_template(
Tensor &output, const Tensor &input_, IntArrayRef padding) {

TORCH_CHECK(canUse32BitIndexMath(input_),
"input tensor must fit into 32-bit index math");

Expand All @@ -260,9 +269,11 @@ void reflection_pad2d_out_template(
int dim_w = 2;
int nbatch = 1;

TORCH_CHECK(input_.numel() > 0 &&
(input_.ndimension() == 3 || input_.ndimension() == 4), "non-empty 3D or "
"4D (batch mode) tensor expected for input, but got: ", input_);
bool valid_dims = input_.size(1) != 0 && input_.size(2) != 0;
TORCH_CHECK(
(input_.ndimension() == 3 && valid_dims) ||
(input_.ndimension() == 4 && valid_dims && input_.size(3) != 0),
"3D or 4D (batch mode) tensor expected for input, but got: ", input_);

if (input_.ndimension() == 4) {
nbatch = input_.size(0);
Expand Down Expand Up @@ -302,6 +313,9 @@ void reflection_pad2d_out_template(
} else {
output.resize_({nbatch, nplane, output_h, output_w});
}
if (output.numel() == 0) {
return;
}

Tensor input = input_.contiguous();

Expand All @@ -326,6 +340,11 @@ void reflection_pad2d_out_template(
void reflection_pad2d_backward_out_template(
Tensor &grad_input, const Tensor &grad_output_,
const Tensor &input, IntArrayRef padding) {

if (grad_input.numel() == 0) {
return;
}

TORCH_CHECK(canUse32BitIndexMath(input),
"input tensor must fit into 32-bit index math");
TORCH_CHECK(canUse32BitIndexMath(grad_output_),
Expand Down
17 changes: 17 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9378,6 +9378,23 @@ def test_GroupNorm_empty(self, device):
with torch.backends.cudnn.flags(enabled=False):
self._test_module_empty_input(mod, inp)

@onlyOnCPUAndCUDA
def test_ReflectionPad_empty(self, device):
for mod, inp in [
(torch.nn.ReflectionPad1d(2), torch.randn(0, 3, 10, device=device)),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you're covering all the cases. Don't you need to check ReflectionPad1d with a 2D and a 3D input with a zero batch size, as well as ReflectionPad2d with a 3D and a 4D input with a zero batch size?

You should also add test cases that fail, where the non-batch dimension is zero and you assert you hit the appropriate error (self.assertRaisesRegex(...)), and a check for backward when the batch dim is zero.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah yes, good point. Will add them — thank you.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The checks for 2D/3D input are done in the Python code in functional.py, and ReflectionPad1d does not get called unless the input is specifically dim=3. The same goes for the 2D reflection pad (with dim=4). Therefore such tests are not needed.

However, I have added tests for the non-batch dimension being zero. The backward pass is tested in _test_module_empty_input.

(torch.nn.ReflectionPad2d(2), torch.randn(0, 3, 10, 10, device=device))]:
self._test_module_empty_input(mod, inp, check_size=False)

with self.assertRaisesRegex(RuntimeError, '2D or 3D'):
mod = torch.nn.ReflectionPad1d(2)
inp = torch.randn(3, 0, 10, device=device)
mod(inp)

with self.assertRaisesRegex(RuntimeError, '3D or 4D'):
mod = torch.nn.ReflectionPad2d(2)
inp = torch.randn(3, 0, 10, 10, device=device)
mod(inp)

def test_BatchNorm_empty(self, device):
mod = torch.nn.BatchNorm2d(3).to(device)
inp = torch.randn(0, 3, 2, 2, device=device)
Expand Down