Skip to content

Commit 1318945

Browse files
committed
Gradient function for Conv2D
1 parent d3002c0 commit 1318945

File tree

5 files changed

+122
-2
lines changed

5 files changed

+122
-2
lines changed

src/TensorFlowNET.Core/Gradients/nn_grad.cs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,50 @@ public static Tensor[] _SparseSoftmaxCrossEntropyWithLogitsGrad(Operation op, Te
9797
};
9898
}
9999

100+
/// <summary>
/// Gradient function for Conv2D.
/// Builds a Conv2DBackpropInput op and a Conv2DBackpropFilter op that share
/// the forward op's attributes and the incoming output gradient.
/// </summary>
/// <param name="op">
/// The forward Conv2D operation. <c>op.inputs[0]</c> is used as the input
/// and <c>op.inputs[1]</c> as the filter.
/// </param>
/// <param name="grads">
/// Incoming gradients; <c>grads[0]</c> is forwarded as <c>out_backprop</c>
/// to both backprop ops.
/// </param>
/// <returns>
/// A two-element array: the gradient with respect to the input, followed by
/// the gradient with respect to the filter.
/// </returns>
[RegisterGradient("Conv2D")]
public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads)
{
    // `as` already yields null for a null (or non-int[]) attribute value,
    // so no explicit null check is needed for the list attributes.
    var dilations = op.get_attr("dilations") as int[];
    var strides = op.get_attr("strides") as int[];
    var explicit_paddings = op.get_attr("explicit_paddings") as int[];
    var padding = op.get_attr("padding").ToString();
    var data_format = op.get_attr("data_format").ToString();
    var use_cudnn_on_gpu = (bool)op.get_attr("use_cudnn_on_gpu");

    // Shapes of the forward input and filter; the backprop ops take them
    // as `input_sizes` / `filter_sizes`.
    var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] });

    // Gradient flowing back from the Conv2D output. Without it, both
    // backprop ops would be built with a null `out_backprop` input.
    var out_backprop = grads[0];

    return new Tensor[]
    {
        gen_nn_ops.conv2d_backprop_input(new Conv2dParams
        {
            InputSizes = shape[0],
            Filter = op.inputs[1],
            OutBackProp = out_backprop,
            Dilations = dilations,
            Strides = strides,
            Padding = padding,
            ExplicitPaddings = explicit_paddings,
            UseCudnnOnGpu = use_cudnn_on_gpu,
            DataFormat = data_format
        }),
        gen_nn_ops.conv2d_backprop_filter(new Conv2dParams
        {
            Input = op.inputs[0],
            FilterSizes = shape[1],
            OutBackProp = out_backprop,
            Dilations = dilations,
            Strides = strides,
            Padding = padding,
            ExplicitPaddings = explicit_paddings,
            UseCudnnOnGpu = use_cudnn_on_gpu,
            DataFormat = data_format
        })
    };
}
143+
100144
private static bool IsZero(Tensor g)
101145
{
102146
if (new string[] { "ZerosLike", "Zeros" }.Contains(g.op.type))

src/TensorFlowNET.Core/Operations/NnOps/Conv2dParams.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,27 @@ public class Conv2dParams
2222
/// </summary>
2323
public Tensor Input { get; set; }
2424

25+
/// <summary>
26+
/// An integer vector representing the shape of `input`
27+
/// </summary>
28+
public Tensor InputSizes { get; set; }
29+
2530
/// <summary>
2631
/// A 4-D tensor of shape `[filter_height, filter_width, in_channels, out_channels]`
2732
/// </summary>
2833
public Tensor Filter { get; set; }
2934

35+
/// <summary>
36+
/// An integer vector representing the tensor shape of `filter`
37+
/// </summary>
38+
public Tensor FilterSizes { get; set; }
39+
40+
/// <summary>
41+
/// A `Tensor`. Must have the same type as `filter`.
42+
/// 4-D with shape `[batch, out_height, out_width, out_channels]`.
43+
/// </summary>
44+
public Tensor OutBackProp { get; set; }
45+
3046
/// <summary>
3147
/// The stride of the sliding window for each
3248
/// dimension of `input`. The dimension order is determined by the value of

src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,52 @@ public static Tensor conv2d(Conv2dParams parameters)
4343
});
4444

4545
return _op.outputs[0];
46+
}
47+
48+
/// <summary>
/// Creates a "Conv2DBackpropFilter" op (gradient of convolution with respect
/// to the filter) and returns its first output.
/// </summary>
/// <param name="parameters">
/// Supplies the op name and the input, filter_sizes, out_backprop, strides,
/// padding, use_cudnn_on_gpu, explicit_paddings, data_format and dilations
/// arguments.
/// </param>
/// <returns>The first output tensor of the created op.</returns>
public static Tensor conv2d_backprop_filter(Conv2dParams parameters)
{
    // Property names and order mirror the arguments handed to the op helper.
    var opArgs = new
    {
        input = parameters.Input,
        filter_sizes = parameters.FilterSizes,
        out_backprop = parameters.OutBackProp,
        strides = parameters.Strides,
        padding = parameters.Padding,
        use_cudnn_on_gpu = parameters.UseCudnnOnGpu,
        explicit_paddings = parameters.ExplicitPaddings,
        data_format = parameters.DataFormat,
        dilations = parameters.Dilations
    };

    var backpropOp = _op_def_lib._apply_op_helper(
        "Conv2DBackpropFilter",
        name: parameters.Name,
        args: opArgs);

    return backpropOp.outputs[0];
}
70+
71+
/// <summary>
/// Creates a "Conv2DBackpropInput" op (gradient of convolution with respect
/// to the input) and returns its first output.
/// </summary>
/// <param name="parameters">
/// Supplies the op name and the input_sizes, filter, out_backprop, strides,
/// padding, use_cudnn_on_gpu, explicit_paddings, data_format and dilations
/// arguments.
/// </param>
/// <returns>The first output tensor of the created op.</returns>
public static Tensor conv2d_backprop_input(Conv2dParams parameters)
{
    // Property names and order mirror the arguments handed to the op helper.
    var opArgs = new
    {
        input_sizes = parameters.InputSizes,
        filter = parameters.Filter,
        out_backprop = parameters.OutBackProp,
        strides = parameters.Strides,
        padding = parameters.Padding,
        use_cudnn_on_gpu = parameters.UseCudnnOnGpu,
        explicit_paddings = parameters.ExplicitPaddings,
        data_format = parameters.DataFormat,
        dilations = parameters.Dilations
    };

    var backpropOp = _op_def_lib._apply_op_helper(
        "Conv2DBackpropInput",
        name: parameters.Name,
        args: opArgs);

    return backpropOp.outputs[0];
}
4793

4894
public static Tensor bias_add(Tensor value,

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,19 @@ public static Tensor shape(Tensor input, TF_DataType out_type = TF_DataType.TF_I
252252
return _op.outputs[0];
253253
}
254254

255+
/// <summary>
/// Creates a "ShapeN" op over the given tensors and returns all of its outputs.
/// </summary>
/// <param name="input">Tensors passed to the op as its <c>input</c> argument.</param>
/// <param name="out_type">Value passed as the op's <c>out_type</c> argument; defaults to TF_INT32.</param>
/// <param name="name">Optional name for the created op.</param>
/// <returns>The output tensors of the created op.</returns>
public static Tensor[] shape_n(Tensor[] input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
{
    var opArgs = new { input, out_type };
    var shapeOp = _op_def_lib._apply_op_helper("ShapeN", name, opArgs);

    return shapeOp.outputs;
}
267+
255268
public static Tensor size(Tensor input, TF_DataType out_type = TF_DataType.TF_INT32, string name = null)
256269
{
257270
var _op = _op_def_lib._apply_op_helper("Size", name, new { input, out_type });

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ Docs: https://tensorflownet.readthedocs.io</Description>
2020
<AssemblyVersion>0.8.1.0</AssemblyVersion>
2121
<PackageReleaseNotes>Changes since v0.8:
2222

23-
1. Removed global static graph instance.
24-
2. Provide custom gradient function.</PackageReleaseNotes>
23+
1. Remove global static graph instance.
24+
2. Provide custom gradient function.
25+
3. Add gradient function for Conv2D.</PackageReleaseNotes>
2526
<LangVersion>7.2</LangVersion>
2627
<FileVersion>0.8.1.0</FileVersion>
2728
</PropertyGroup>

0 commit comments

Comments
 (0)