Skip to content

Commit 85f4d2b

Browse files
li-roysoumith
authored andcommitted
throw error when grid_sample is passed unsupported mode (#8884)
1 parent f74207c commit 85f4d2b

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

test/test_nn.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4574,6 +4574,10 @@ def test_cosine_similarity(self):
45744574
input2 = torch.randn(input_size, requires_grad=True)
45754575
self.assertEqual(F.cosine_similarity(input1, input2, dim=1).size(), expected_size)
45764576

4577+
def test_grid_sample_unsupported_mode(self):
4578+
with self.assertRaisesRegex(NotImplementedError, "nn.functional.grid_sample got unsupported mode: 'garbage'"):
4579+
F.grid_sample(torch.tensor([]), torch.tensor([]), mode='garbage')
4580+
45774581
def test_grid_sample(self):
45784582
def test_cpu_against_cuda(N, C, H, W, padding_mode):
45794583
def test_shape(N, C, IH, IW, H, W, padding_mode):

torch/nn/functional.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1902,6 +1902,8 @@ def grid_sample(input, grid, mode='bilinear', padding_mode='zeros'):
19021902
output (Tensor): output Tensor
19031903
19041904
"""
1905+
if mode != 'bilinear':
1906+
raise NotImplementedError("nn.functional.grid_sample got unsupported mode: '{}'".format(mode))
19051907
return vision.grid_sampler(input, grid, padding_mode)
19061908

19071909

0 commit comments

Comments
 (0)