
Commit f3e16cc

vedanuj authored and ezyang committed
Expose gradients w.r.t. input & weight for conv1d, conv2d, conv3d in Python (#5408)
This PR addresses issue #5024

* Expose Conv2dBackward in Python
* Separate interface for exposing gradients of operators
* Revert old changes
* Add tests
* Add conv1d gradients; refactor tests for grad convolutions
* Refactor names and change examples
* Remove Variable from tests for conv backward
1 parent 8317803 commit f3e16cc
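
For orientation (an annotation, not part of the commit itself): the new functions live in torch.nn.grad and, thanks to the one-line import added to torch/nn/functional.py below, are also reachable as F.grad.*. A minimal usage sketch, reusing the shapes from the conv2d docstring examples in this diff:

    import torch
    import torch.nn.functional as F

    input = torch.randn(1, 1, 3, 3, requires_grad=True)
    weight = torch.randn(1, 1, 1, 2, requires_grad=True)
    output = F.conv2d(input, weight)
    grad_output = torch.randn(output.shape)

    # Gradient w.r.t. the input takes the input *shape*, not the tensor.
    grad_input = F.grad.conv2d_input(input.shape, weight, grad_output)

    # Gradient w.r.t. the weight takes the weight *shape*, not the tensor.
    grad_weight = F.grad.conv2d_weight(input, weight.shape, grad_output)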

File tree

    test/test_nn.py
    torch/nn/functional.py
    torch/nn/grad.py

3 files changed, +362 -0 lines changed

test/test_nn.py

Lines changed: 45 additions & 0 deletions

@@ -4506,6 +4506,51 @@ def test_conv_double_backward_cuda(self, dtype=torch.FloatTensor):
                             "\ninp_size: " + str(inp_size) +
                             "\ndilation: " + str(dilation))
 
+    def run_grad_conv_test(self, func_forward, func_backward, dim=1, gradient='input'):
+        for kern, inp_size in [(3, 6), (3, 7), (4, 9)]:
+            for batch, stride, padding, chan_in, chan_out, dilation in \
+                    product([1, 2], [1, 2], [0, 1, 2], [2], [3], [1]):
+
+                input_shape = [batch, chan_in]
+                weight_shape = [chan_out, chan_in]
+                for _ in range(dim):
+                    input_shape.append(inp_size)
+                    weight_shape.append(kern)
+
+                input = torch.randn(input_shape, requires_grad=True)
+                weight = torch.randn(weight_shape, requires_grad=True)
+                output = func_forward(input, weight, stride=stride, padding=padding, dilation=dilation)
+
+                gradient_o = torch.randn(output.shape)
+                gradient_w = torch.autograd.grad(output, input if (gradient == 'input') else weight, gradient_o)
+
+                self.assertAlmostEqual(gradient_w[0],
+                                       func_backward(
+                                           input_shape if (gradient == 'input') else input,
+                                           weight_shape if (gradient == 'weight') else weight,
+                                           gradient_o,
+                                           stride=stride,
+                                           padding=padding,
+                                           dilation=dilation))
+
+    def test_grad_conv1d_input(self):
+        self.run_grad_conv_test(F.conv1d, F.grad.conv1d_input, 1, 'input')
+
+    def test_grad_conv1d_weight(self):
+        self.run_grad_conv_test(F.conv1d, F.grad.conv1d_weight, 1, 'weight')
+
+    def test_grad_conv2d_input(self):
+        self.run_grad_conv_test(F.conv2d, F.grad.conv2d_input, 2, 'input')
+
+    def test_grad_conv2d_weight(self):
+        self.run_grad_conv_test(F.conv2d, F.grad.conv2d_weight, 2, 'weight')
+
+    def test_grad_conv3d_input(self):
+        self.run_grad_conv_test(F.conv3d, F.grad.conv3d_input, 3, 'input')
+
+    def test_grad_conv3d_weight(self):
+        self.run_grad_conv_test(F.conv3d, F.grad.conv3d_weight, 3, 'weight')
 
 
 class TestNNInit(TestCase):
     def setUp(self):
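
The tests sweep kernel size, input size, batch, stride, and padding, and compare each explicit gradient against torch.autograd.grad. As a side note (not part of the diff), the output length that makes those shapes line up follows the standard convolution size formula; checking it for the first 1D configuration above:

    # oW = floor((iW + 2*padding - dilation*(kern - 1) - 1) / stride) + 1
    iW, padding, dilation, kern, stride = 6, 0, 1, 3, 1
    oW = (iW + 2 * padding - dilation * (kern - 1) - 1) // stride + 1
    assert oW == 4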

torch/nn/functional.py

Lines changed: 1 addition & 0 deletions

@@ -14,6 +14,7 @@
 from ._functions.thnn.fold import Col2Im, Im2Col
 from torch.autograd import Variable
 from .modules.utils import _single, _pair, _triple
+from . import grad
 
 
 conv1d = _add_docstr(torch.conv1d, r"""

torch/nn/grad.py

Lines changed: 316 additions & 0 deletions

@@ -0,0 +1,316 @@
+"""Gradient interface"""
+
+import torch
+from .modules.utils import _single, _pair, _triple
+
+
+def _grad_input_padding(grad_output, input_size, stride, padding, kernel_size):
+    input_size = list(input_size)
+    k = grad_output.dim() - 2
+
+    if len(input_size) == k + 2:
+        input_size = input_size[-k:]
+    if len(input_size) != k:
+        raise ValueError("input_size must have {} elements (got {})"
+                         .format(k + 2, len(input_size)))
+
+    def dim_size(d):
+        return ((grad_output.size(d + 2) - 1) * stride[d] - 2 * padding[d] +
+                kernel_size[d])
+
+    min_sizes = [dim_size(d) for d in range(k)]
+    max_sizes = [min_sizes[d] + stride[d] - 1 for d in range(k)]
+    for size, min_size, max_size in zip(input_size, min_sizes, max_sizes):
+        if size < min_size or size > max_size:
+            raise ValueError(
+                ("requested an input grad size of {}, but valid sizes range "
+                 "from {} to {} (for a grad_output of {})").format(
+                     input_size, min_sizes, max_sizes,
+                     grad_output.size()[2:]))
+
+    return tuple(input_size[d] - min_sizes[d] for d in range(k))
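
To trace what _grad_input_padding returns, take some concrete (hypothetical) numbers: grad_output of length 4, stride 2, padding 0, kernel 3. Under a strided convolution several input lengths map to the same output length, and the returned value is the output padding that conv_transpose*d needs in order to pick the right one:

    # Hypothetical numbers, tracing the arithmetic in _grad_input_padding:
    grad_len, stride, padding, kern = 4, 2, 0, 3
    min_size = (grad_len - 1) * stride - 2 * padding + kern   # 9
    max_size = min_size + stride - 1                          # 10
    requested = 10                                            # a valid input size
    output_padding = requested - min_size                     # 1
    assert (min_size, max_size, output_padding) == (9, 10, 1)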
+
+
+def conv1d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1, bias=None):
+    r"""
+    Computes the gradient of conv1d with respect to the input of the convolution.
+    This is the same as the 1D transposed convolution operator under the hood, but it requires
+    the shape of the gradient w.r.t. input to be specified explicitly.
+
+    Args:
+        input_size : Shape of the input gradient tensor
+        weight: weight tensor (out_channels x in_channels/groups x kW)
+        grad_output : output gradient tensor (minibatch x out_channels x oW)
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+        bias: optional bias tensor (out_channels). Default: None
+
+    Examples::
+
+        >>> input = torch.randn(1, 1, 3, requires_grad=True)
+        >>> weight = torch.randn(1, 1, 1, requires_grad=True)
+        >>> output = F.conv1d(input, weight)
+        >>> grad_output = torch.randn(output.shape)
+        >>> grad_input = torch.autograd.grad(output, input, grad_output)
+        >>> F.grad.conv1d_input(input.shape, weight, grad_output)
+
+    """
+    stride = _single(stride)
+    padding = _single(padding)
+    dilation = _single(dilation)
+    kernel_size = [weight.shape[2]]
+
+    if input_size is None:
+        raise ValueError("grad.conv1d_input requires specifying an input_size")
+
+    grad_input_padding = _grad_input_padding(grad_output, input_size, stride,
+                                             padding, kernel_size)
+
+    return torch._C._VariableFunctions.conv_transpose1d(
+        grad_output, weight, bias, stride, padding, grad_input_padding, groups,
+        dilation)
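
The docstring example doubles as a self-check; this sketch (mine, not part of the diff, and assuming torch.allclose is available) compares the explicit gradient with autograd's:

    import torch
    import torch.nn.functional as F

    input = torch.randn(1, 1, 3, requires_grad=True)
    weight = torch.randn(1, 1, 1, requires_grad=True)
    output = F.conv1d(input, weight)
    grad_output = torch.randn(output.shape)

    (grad_input_autograd,) = torch.autograd.grad(output, input, grad_output)
    grad_input_explicit = F.grad.conv1d_input(input.shape, weight, grad_output)
    assert torch.allclose(grad_input_autograd, grad_input_explicit)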
+
+
+def conv1d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1, bias=None):
+    r"""
+    Computes the gradient of conv1d with respect to the weight of the convolution.
+
+    Args:
+        input: input tensor of shape (minibatch x in_channels x iW)
+        weight_size : Shape of the weight gradient tensor
+        grad_output : output gradient tensor (minibatch x out_channels x oW)
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+        bias: optional bias tensor (out_channels). Default: None
+
+    Examples::
+
+        >>> input = torch.randn(1, 1, 3, requires_grad=True)
+        >>> weight = torch.randn(1, 1, 1, requires_grad=True)
+        >>> output = F.conv1d(input, weight)
+        >>> grad_output = torch.randn(output.shape)
+        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
+        >>> F.grad.conv1d_weight(input, weight.shape, grad_output)
+
+    """
+    stride = _single(stride)
+    padding = _single(padding)
+    dilation = _single(dilation)
+    in_channels = input.shape[1]
+    out_channels = grad_output.shape[1]
+    min_batch = input.shape[0]
+
+    grad_output = grad_output.contiguous().repeat(1, in_channels // groups, 1)
+    grad_output = grad_output.contiguous().view(
+        grad_output.shape[0] * grad_output.shape[1], 1, grad_output.shape[2])
+
+    input = input.contiguous().view(1, input.shape[0] * input.shape[1],
+                                    input.shape[2])
+
+    grad_weight = torch._C._VariableFunctions.conv1d(input, grad_output, bias,
+                                                     dilation, padding, stride,
+                                                     in_channels * min_batch)
+
+    grad_weight = grad_weight.contiguous().view(
+        min_batch, grad_weight.shape[1] // min_batch, grad_weight.shape[2])
+
+    return grad_weight.sum(dim=0).view(
+        in_channels // groups, out_channels, grad_weight.shape[2]).transpose(
+            0, 1).narrow(2, 0, weight_size[2])
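
Two things are easy to miss in conv1d_weight: the batch and channel dimensions are folded into one grouped convolution with groups = in_channels * min_batch, and the forward stride and dilation swap positions in the conv1d call (the weight gradient is a correlation of the input with grad_output dilated by the forward stride). A quick consistency check with non-default stride and padding, mirroring the test shapes above (my sketch, not part of the diff):

    import torch
    import torch.nn.functional as F

    input = torch.randn(2, 2, 7, requires_grad=True)
    weight = torch.randn(3, 2, 3, requires_grad=True)
    output = F.conv1d(input, weight, stride=2, padding=1)
    grad_output = torch.randn(output.shape)

    (grad_weight_autograd,) = torch.autograd.grad(output, weight, grad_output)
    grad_weight_explicit = F.grad.conv1d_weight(input, weight.shape, grad_output,
                                                stride=2, padding=1)
    assert torch.allclose(grad_weight_autograd, grad_weight_explicit, atol=1e-5)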
+
+
+def conv2d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1, bias=None):
+    r"""
+    Computes the gradient of conv2d with respect to the input of the convolution.
+    This is the same as the 2D transposed convolution operator under the hood, but it requires
+    the shape of the gradient w.r.t. input to be specified explicitly.
+
+    Args:
+        input_size : Shape of the input gradient tensor
+        weight: weight tensor (out_channels x in_channels/groups x kH x kW)
+        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+        bias: optional bias tensor (out_channels). Default: None
+
+    Examples::
+
+        >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
+        >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
+        >>> output = F.conv2d(input, weight)
+        >>> grad_output = torch.randn(output.shape)
+        >>> grad_input = torch.autograd.grad(output, input, grad_output)
+        >>> F.grad.conv2d_input(input.shape, weight, grad_output)
+
+    """
+    stride = _pair(stride)
+    padding = _pair(padding)
+    dilation = _pair(dilation)
+    kernel_size = (weight.shape[2], weight.shape[3])
+
+    if input_size is None:
+        raise ValueError("grad.conv2d_input requires specifying an input_size")
+
+    grad_input_padding = _grad_input_padding(grad_output, input_size, stride,
+                                             padding, kernel_size)
+
+    return torch._C._VariableFunctions.conv_transpose2d(
+        grad_output, weight, bias, stride, padding, grad_input_padding, groups,
+        dilation)
+
+
+def conv2d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1, bias=None):
+    r"""
+    Computes the gradient of conv2d with respect to the weight of the convolution.
+
+    Args:
+        input: input tensor of shape (minibatch x in_channels x iH x iW)
+        weight_size : Shape of the weight gradient tensor
+        grad_output : output gradient tensor (minibatch x out_channels x oH x oW)
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+        bias: optional bias tensor (out_channels). Default: None
+
+    Examples::
+
+        >>> input = torch.randn(1, 1, 3, 3, requires_grad=True)
+        >>> weight = torch.randn(1, 1, 1, 2, requires_grad=True)
+        >>> output = F.conv2d(input, weight)
+        >>> grad_output = torch.randn(output.shape)
+        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
+        >>> F.grad.conv2d_weight(input, weight.shape, grad_output)
+
+    """
+    stride = _pair(stride)
+    padding = _pair(padding)
+    dilation = _pair(dilation)
+    in_channels = input.shape[1]
+    out_channels = grad_output.shape[1]
+    min_batch = input.shape[0]
+
+    grad_output = grad_output.contiguous().repeat(1, in_channels // groups, 1,
+                                                  1)
+    grad_output = grad_output.contiguous().view(
+        grad_output.shape[0] * grad_output.shape[1], 1, grad_output.shape[2],
+        grad_output.shape[3])
+
+    input = input.contiguous().view(1, input.shape[0] * input.shape[1],
+                                    input.shape[2], input.shape[3])
+
+    grad_weight = torch._C._VariableFunctions.conv2d(input, grad_output, bias,
+                                                     dilation, padding, stride,
+                                                     in_channels * min_batch)
+
+    grad_weight = grad_weight.contiguous().view(
+        min_batch, grad_weight.shape[1] // min_batch, grad_weight.shape[2],
+        grad_weight.shape[3])
+
+    return grad_weight.sum(dim=0).view(
+        in_channels // groups, out_channels,
+        grad_weight.shape[2], grad_weight.shape[3]).transpose(0, 1).narrow(
+            2, 0, weight_size[2]).narrow(3, 0, weight_size[3])
+
+
+def conv3d_input(input_size, weight, grad_output, stride=1, padding=0, dilation=1, groups=1, bias=None):
+    r"""
+    Computes the gradient of conv3d with respect to the input of the convolution.
+    This is the same as the 3D transposed convolution operator under the hood, but it requires
+    the shape of the gradient w.r.t. input to be specified explicitly.
+
+    Args:
+        input_size : Shape of the input gradient tensor
+        weight: weight tensor (out_channels x in_channels/groups x kT x kH x kW)
+        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+        bias: optional bias tensor (out_channels). Default: None
+
+    Examples::
+
+        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
+        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
+        >>> output = F.conv3d(input, weight)
+        >>> grad_output = torch.randn(output.shape)
+        >>> grad_input = torch.autograd.grad(output, input, grad_output)
+        >>> F.grad.conv3d_input(input.shape, weight, grad_output)
+
+    """
+    stride = _triple(stride)
+    padding = _triple(padding)
+    dilation = _triple(dilation)
+    kernel_size = (weight.shape[2], weight.shape[3], weight.shape[4])
+
+    if input_size is None:
+        raise ValueError("grad.conv3d_input requires specifying an input_size")
+
+    grad_input_padding = _grad_input_padding(grad_output, input_size, stride,
+                                             padding, kernel_size)
+
+    return torch._C._VariableFunctions.conv_transpose3d(
+        grad_output, weight, bias, stride, padding, grad_input_padding, groups,
+        dilation)
+
+
+def conv3d_weight(input, weight_size, grad_output, stride=1, padding=0, dilation=1, groups=1, bias=None):
+    r"""
+    Computes the gradient of conv3d with respect to the weight of the convolution.
+
+    Args:
+        input: input tensor of shape (minibatch x in_channels x iT x iH x iW)
+        weight_size : Shape of the weight gradient tensor
+        grad_output : output gradient tensor (minibatch x out_channels x oT x oH x oW)
+        stride (int or tuple, optional): Stride of the convolution. Default: 1
+        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
+        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
+        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
+        bias: optional bias tensor (out_channels). Default: None
+
+    Examples::
+
+        >>> input = torch.randn(2, 8, 10, 10, 20, requires_grad=True)
+        >>> weight = torch.randn(4, 8, 2, 3, 3, requires_grad=True)
+        >>> output = F.conv3d(input, weight)
+        >>> grad_output = torch.randn(output.shape)
+        >>> grad_weight = torch.autograd.grad(output, weight, grad_output)
+        >>> F.grad.conv3d_weight(input, weight.shape, grad_output)
+
+    """
+    stride = _triple(stride)
+    padding = _triple(padding)
+    dilation = _triple(dilation)
+    in_channels = input.shape[1]
+    out_channels = grad_output.shape[1]
+    min_batch = input.shape[0]
+
+    grad_output = grad_output.repeat(1, in_channels // groups, 1, 1, 1)
+    grad_output = grad_output.contiguous().view(
+        grad_output.shape[0] * grad_output.shape[1], 1, grad_output.shape[2],
+        grad_output.shape[3], grad_output.shape[4])
+
+    input = input.contiguous().view(1, input.shape[0] * input.shape[1],
+                                    input.shape[2], input.shape[3],
+                                    input.shape[4])
+
+    grad_weight = torch._C._VariableFunctions.conv3d(input, grad_output, bias,
+                                                     dilation, padding, stride,
+                                                     in_channels * min_batch)
+
+    grad_weight = grad_weight.contiguous().view(
+        min_batch, grad_weight.shape[1] // min_batch, grad_weight.shape[2],
+        grad_weight.shape[3], grad_weight.shape[4])
+
+    return grad_weight.sum(dim=0).view(
+        in_channels // groups, out_channels, grad_weight.shape[2],
+        grad_weight.shape[3], grad_weight.shape[4]).transpose(0, 1).narrow(
+            2, 0, weight_size[2]).narrow(3, 0, weight_size[3]).narrow(
+                4, 0, weight_size[4])
