Skip to content

Commit 8e2e31c

Browse files
committed
Add gradients for the GatherV2 and MaxPool ops.
1 parent 1318945 commit 8e2e31c

File tree

15 files changed

+224
-61
lines changed

15 files changed

+224
-61
lines changed

TensorFlow.NET.sln

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Keras.UnitTest", "test\Kera
1717
EndProject
1818
Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "TensorFlowNET.Examples.FSharp", "test\TensorFlowNET.Examples.FSharp\TensorFlowNET.Examples.FSharp.fsproj", "{62BC3801-F0D3-44A9-A0AC-712F40C8F961}"
1919
EndProject
20+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{92762DCB-64C8-41B4-BEF7-780A969CE68F}"
21+
EndProject
2022
Global
2123
GlobalSection(SolutionConfigurationPlatforms) = preSolution
2224
Debug|Any CPU = Debug|Any CPU
@@ -51,6 +53,10 @@ Global
5153
{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Debug|Any CPU.Build.0 = Debug|Any CPU
5254
{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.ActiveCfg = Release|Any CPU
5355
{62BC3801-F0D3-44A9-A0AC-712F40C8F961}.Release|Any CPU.Build.0 = Release|Any CPU
56+
{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
57+
{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Debug|Any CPU.Build.0 = Debug|Any CPU
58+
{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Release|Any CPU.ActiveCfg = Release|Any CPU
59+
{92762DCB-64C8-41B4-BEF7-780A969CE68F}.Release|Any CPU.Build.0 = Release|Any CPU
5460
EndGlobalSection
5561
GlobalSection(SolutionProperties) = preSolution
5662
HideSolutionNode = FALSE

src/TensorFlowNET.Core/APIs/tf.layers.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ public static Tensor dense(Tensor inputs,
142142

143143
var layer = new Dense(units, activation,
144144
use_bias: use_bias,
145+
bias_initializer: bias_initializer,
145146
kernel_initializer: kernel_initializer);
146147

147148
return layer.apply(inputs);
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Framework
6+
{
7+
/// <summary>
/// Abstract base class for Tensor-like objects that are composed from Tensors.
/// </summary>
/// <remarks>
/// The class declares no members of its own; derived types supply the
/// component tensors and any behavior.
/// </remarks>
public abstract class CompositeTensor { }
13+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Framework
6+
{
7+
/// <summary>
/// A sparse representation of a set of tensor slices at given indices.
/// </summary>
/// <remarks>
/// The implicit conversion to <see cref="Tensor"/> yields only <see cref="values"/>;
/// <see cref="indices"/> and <see cref="dense_shape"/> are not carried over by it.
/// </remarks>
public class IndexedSlices : CompositeTensor
{
    Tensor _values;
    /// <summary>The tensor passed as <c>values</c> to the constructor.</summary>
    public Tensor values => _values;

    Tensor _indices;
    /// <summary>The tensor passed as <c>indices</c> to the constructor.</summary>
    public Tensor indices => _indices;

    Tensor _dense_shape;
    /// <summary>The tensor passed as <c>dense_shape</c> to the constructor, or null when omitted.</summary>
    public Tensor dense_shape => _dense_shape;

    /// <summary>
    /// Creates an IndexedSlices from its component tensors.
    /// </summary>
    /// <param name="values">Tensor holding the slice values.</param>
    /// <param name="indices">Tensor holding the indices of the slices.</param>
    /// <param name="dense_shape">Optional tensor describing the dense shape; may be null.</param>
    public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null)
    {
        // The constructor body used to be empty, which left every component
        // null and made `values` (and the implicit Tensor conversion) return null.
        _values = values;
        _indices = indices;
        _dense_shape = dense_shape;
    }

    /// <summary>
    /// Converts an IndexedSlices to its <see cref="values"/> tensor.
    /// </summary>
    /// <param name="indexedSlices">The instance to convert; may be null.</param>
    /// <returns>The values tensor, or null when <paramref name="indexedSlices"/> is null.</returns>
    public static implicit operator Tensor(IndexedSlices indexedSlices)
    {
        // A null source converts to a null Tensor instead of throwing
        // from inside an implicit conversion.
        return indexedSlices?.values;
    }
}
25+
}

src/TensorFlowNET.Core/Gradients/array_grad.cs

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System.Collections.Generic;
33
using System.Linq;
44
using System.Text;
5+
using Tensorflow.Framework;
56
using Tensorflow.Operations;
67
using static Tensorflow.Python;
78

@@ -42,9 +43,9 @@ private static Tensor[] _ConcatGradHelper(Operation op, Tensor grad, int start_v
4243
return end_value_index <= dim_index ? new Tensor[] { grad, null } : new Tensor[] { null, grad };
4344

4445
var concat_dim = op.inputs[dim_index];
45-
if (end_value_index == -1)
46-
end_value_index = op.inputs.Length - 1;
47-
var input_values = op.inputs._inputs.Skip(start_value_index).Take(end_value_index - start_value_index).ToArray();
46+
var input_values = op.inputs._inputs.Skip(start_value_index)
47+
.Take(end_value_index == -1 ? op.inputs.Length - 1 : end_value_index - start_value_index)
48+
.ToArray();
4849

4950
var out_grads = new List<Tensor>();
5051
if (constant_op.is_constant(concat_dim))
@@ -92,10 +93,16 @@ there will be a small number of performance regressions.*/
9293
}
9394

9495
return (end_value_index <= dim_index ?
95-
out_grads.ToArray().Concat(null) :
96+
out_grads.ToArray().Concat(new Tensor[] { null }) :
9697
new Tensor[] { null }.Concat(out_grads)).ToArray();
9798
}
9899

100+
/// <summary>
/// Gradient for the ExpandDims op.
/// </summary>
/// <param name="op">The forward ExpandDims operation.</param>
/// <param name="grads">Incoming gradients; only the first entry is used.</param>
/// <returns>
/// A two-element array: the result of <c>_ReshapeToInput</c> applied to the first
/// incoming gradient, followed by null.
/// </returns>
[RegisterGradient("ExpandDims")]
public static Tensor[] _ExpandDimsGrad(Operation op, Tensor[] grads)
{
    Tensor input_grad = _ReshapeToInput(op, grads[0]);

    // The second slot carries no gradient.
    return new Tensor[] { input_grad, null };
}
105+
99106
/// <summary>
100107
/// Extract the shapes of a set of input tensors.
101108
/// </summary>
@@ -125,6 +132,45 @@ private static Tensor[] _ExtractInputShapes(Tensor[] inputs)
125132
return gen_ops.shape_n(inputs);
126133
}
127134

135+
/// <summary>
/// Gradient for GatherV2 op.
/// </summary>
/// <param name="op">The forward GatherV2 operation; inputs 0, 1 and 2 are read as params, indices and axis.</param>
/// <param name="grads">Incoming gradients; only the first entry is used.</param>
/// <returns>
/// A three-element array, one entry per input read from <paramref name="op"/>.
/// When the axis is a constant equal to 0 the first entry is an
/// <see cref="IndexedSlices"/> (implicitly converted to Tensor) and the other
/// two are null. In every other case all three entries are null, because no
/// gradient is built for those cases here.
/// </returns>
[RegisterGradient("GatherV2")]
public static Tensor[] _GatherV2Grad(Operation op, Tensor[] grads)
{
    var grad = grads[0];
    var @params = op.inputs[0];
    // NOTE(review): the result of colocate_with is discarded here; confirm
    // whether it is meant to scope the ops created below.
    ops.colocate_with(@params);

    var params_shape = array_ops.shape(@params, out_type: tf.int64);
    params_shape = math_ops.cast(params_shape, tf.int32);

    var indices = op.inputs[1];
    var indices_size = array_ops.expand_dims(array_ops.size(indices), 0);
    var axis = op.inputs[2];
    var axis_static = tensor_util.constant_value(axis);

    // For axis 0 gathers, build an appropriately shaped IndexedSlices.
    // Guard against a missing static value before casting, so the cast
    // cannot fail on a null reference.
    if (!object.ReferenceEquals(axis_static, null) && (int)axis_static == 0)
    {
        // NOTE(review): TensorFlow's Python reference uses the tail of the
        // shape (params_shape[1:]) here; this takes the single element at
        // index 1 - confirm the behavior for params of other ranks.
        var params_tail_shape = params_shape[1];
        var values_shape = array_ops.concat(new[] { indices_size, params_tail_shape }, 0);
        var values = array_ops.reshape(grad, values_shape);
        indices = array_ops.reshape(indices, indices_size);
        return new Tensor[]
        {
            new IndexedSlices(values, indices, params_shape),
            null,
            null
        };
    }

    // No gradient is built for other cases. Return one entry per input
    // (params, indices, axis) so the arity matches the axis-0 branch; this
    // used to return only two entries.
    return new Tensor[] { null, null, null };
}
173+
128174
[RegisterGradient("Reshape")]
129175
public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
130176
{

src/TensorFlowNET.Core/Gradients/nn_grad.cs

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@ public static Tensor[] _SparseSoftmaxCrossEntropyWithLogitsGrad(Operation op, Te
106106
[RegisterGradient("Conv2D")]
107107
public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads)
108108
{
109-
var dilations = op.get_attr("dilations");
110-
var strides = op.get_attr("strides");
109+
var dilations = (op.get_attr("dilations") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray();
110+
var strides = (op.get_attr("strides") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray();
111111
var padding = op.get_attr("padding");
112-
var explicit_paddings = op.get_attr("explicit_paddings");
112+
var explicit_paddings = (op.get_attr("explicit_paddings") as AttrValue.Types.ListValue).I.Select(x => Convert.ToInt32(x)).ToArray();
113113
var use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu");
114114
var data_format = op.get_attr("data_format");
115115
var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] });
@@ -120,21 +120,23 @@ public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads)
120120
{
121121
InputSizes = shape[0],
122122
Filter = op.inputs[1],
123-
Dilations = dilations == null ? null : dilations as int[],
124-
Strides = strides == null ? null : strides as int[],
123+
OutBackProp = grads[0],
124+
Dilations = dilations,
125+
Strides = strides,
125126
Padding = padding.ToString(),
126-
ExplicitPaddings = explicit_paddings == null ? null : explicit_paddings as int[],
127+
ExplicitPaddings = explicit_paddings,
127128
UseCudnnOnGpu = (bool)use_cudnn_on_gpu,
128-
DataFormat = data_format.ToString()
129+
DataFormat = data_format.ToString(),
129130
}),
130131
gen_nn_ops.conv2d_backprop_filter(new Conv2dParams
131132
{
132133
Input = op.inputs[0],
133134
FilterSizes = shape[1],
134-
Dilations = dilations == null ? null : dilations as int[],
135-
Strides = strides == null ? null : strides as int[],
135+
OutBackProp = grads[0],
136+
Dilations = dilations,
137+
Strides = strides,
136138
Padding = padding.ToString(),
137-
ExplicitPaddings = explicit_paddings == null ? null : explicit_paddings as int[],
139+
ExplicitPaddings = explicit_paddings,
138140
UseCudnnOnGpu = (bool)use_cudnn_on_gpu,
139141
DataFormat = data_format.ToString()
140142
})
@@ -155,6 +157,23 @@ private static Tensor _BroadcastMul(Tensor vec, Tensor mat)
155157
return vec * mat;
156158
}
157159

160+
/// <summary>
/// Gradient for the MaxPool op.
/// </summary>
/// <param name="op">
/// The forward MaxPool operation; its first input, first output and its
/// "ksize", "strides", "padding" and "data_format" attributes are read.
/// </param>
/// <param name="grads">Incoming gradients; only the first entry is used.</param>
/// <returns>A single-element array holding the result of <c>gen_nn_ops.max_pool_grad</c>.</returns>
[RegisterGradient("MaxPool")]
public static Tensor[] _MaxPoolGrad(Operation op, Tensor[] grads)
{
    var out_grad = grads[0];
    var orig_input = op.inputs[0];
    var orig_output = op.outputs[0];

    // List-valued attributes are converted element by element to int.
    var ksize_attr = op.get_attr("ksize") as AttrValue.Types.ListValue;
    var ksize = ksize_attr.I.Select(v => Convert.ToInt32(v)).ToArray();

    var strides_attr = op.get_attr("strides") as AttrValue.Types.ListValue;
    var strides = strides_attr.I.Select(v => Convert.ToInt32(v)).ToArray();

    var padding = op.get_attr("padding").ToString();
    var data_format = op.get_attr("data_format").ToString();

    var input_grad = gen_nn_ops.max_pool_grad(
        orig_input,
        orig_output,
        out_grad,
        ksize,
        strides,
        padding: padding,
        data_format: data_format);

    return new Tensor[] { input_grad };
}
176+
158177
/// <summary>
159178
/// Return the gradients for TopK.
160179
/// </summary>

src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,23 @@ public static Tensor max_pool(Tensor input,
179179
return _op.outputs[0];
180180
}
181181

182+
/// <summary>
/// Applies the "MaxPoolGrad" op through the op-definition helper.
/// </summary>
/// <param name="orig_input">Passed to the op as <c>orig_input</c>.</param>
/// <param name="orig_output">Passed to the op as <c>orig_output</c>.</param>
/// <param name="grad">Passed to the op as <c>grad</c>.</param>
/// <param name="ksize">Passed to the op as <c>ksize</c>.</param>
/// <param name="strides">Passed to the op as <c>strides</c>.</param>
/// <param name="padding">Passed to the op as <c>padding</c>.</param>
/// <param name="data_format">Passed to the op as <c>data_format</c>; defaults to "NHWC".</param>
/// <param name="name">Optional name forwarded to the helper.</param>
/// <returns>The first output of the created operation.</returns>
public static Tensor max_pool_grad(Tensor orig_input, Tensor orig_output, Tensor grad, int[] ksize, int[] strides, string padding,
    string data_format = "NHWC", string name = null)
{
    // Member names of the anonymous object are taken from the parameter
    // names and must stay exactly as written.
    var op_args = new
    {
        orig_input,
        orig_output,
        grad,
        ksize,
        strides,
        padding,
        data_format
    };

    var max_pool_grad_op = _op_def_lib._apply_op_helper("MaxPoolGrad", name: name, args: op_args);

    return max_pool_grad_op.outputs[0];
}
198+
182199
public static Tensor[] top_kv2(Tensor input, int k, bool sorted = true, string name = null)
183200
{
184201
var _op = _op_def_lib._apply_op_helper("TopKV2", name: name, args: new

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using Google.Protobuf.Collections;
2-
//using Newtonsoft.Json;
2+
#if GRAPH_SERIALIZE
3+
using Newtonsoft.Json;
4+
#endif
35
using System;
46
using System.Collections.Generic;
57
using System.Linq;
@@ -33,16 +35,23 @@ public partial class Operation : ITensorOrOperation
3335
private readonly IntPtr _operDesc;
3436

3537
private Graph _graph;
36-
//[JsonIgnore]
38+
public string type => OpType;
39+
40+
#if GRAPH_SERIALIZE
41+
[JsonIgnore]
42+
public Graph graph => _graph;
43+
[JsonIgnore]
44+
public int _id => _id_value;
45+
[JsonIgnore]
46+
public int _id_value;
47+
[JsonIgnore]
48+
public Operation op => this;
49+
#else
3750
public Graph graph => _graph;
38-
//[JsonIgnore]
3951
public int _id => _id_value;
40-
//[JsonIgnore]
4152
public int _id_value;
42-
43-
public string type => OpType;
44-
//[JsonIgnore]
4553
public Operation op => this;
54+
#endif
4655
public TF_DataType dtype => TF_DataType.DtInvalid;
4756
private Status status = new Status();
4857

@@ -51,7 +60,7 @@ public partial class Operation : ITensorOrOperation
5160
public string Device => c_api.StringPiece(c_api.TF_OperationDevice(_handle));
5261

5362
private NodeDef _node_def;
54-
//[JsonIgnore]
63+
[JsonIgnore]
5564
public NodeDef node_def
5665
{
5766
get

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ private static Tensor shape_internal(Tensor input, string name = null, bool opti
277277
var input_shape = tensor_util.to_shape(input_tensor.shape);
278278
if (optimize && input_tensor.NDims > -1 && input_shape.is_fully_defined())
279279
{
280-
var nd = np.array(input_tensor.shape, out_type.as_numpy_datatype());
280+
var nd = np.array(input_tensor.shape).astype(out_type.as_numpy_datatype());
281281
return constant_op.constant(nd, name: name);
282282
}
283283
}

src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ public static Tensor[] tuple(Tensor[] tensors, string name = null, Operation[] c
123123
return with(ops.name_scope(name, "tuple", tensors), scope =>
124124
{
125125
name = scope;
126-
var gating_ops = tensors.Select(x => x.op).ToList();
126+
var gating_ops = tensors.Where(x => x != null).Select(x => x.op).ToList();
127127

128128
if(control_inputs != null)
129129
{
@@ -139,7 +139,10 @@ public static Tensor[] tuple(Tensor[] tensors, string name = null, Operation[] c
139139
var tpl = new List<Tensor>();
140140
foreach(var t in tensors)
141141
{
142-
tpl.Add(with_dependencies(new Operation[] { gate }, t));
142+
if (t != null)
143+
tpl.Add(with_dependencies(new Operation[] { gate }, t));
144+
else
145+
tpl.Add(null);
143146
}
144147

145148
return tpl.ToArray();

0 commit comments

Comments
 (0)