Skip to content

Commit 16da1c2

Browse files
committed
add array_ops.unique, math_ops.sqrt, Optimizer._apply_sparse_duplicate_indices
1 parent 0595ea1 commit 16da1c2

File tree

15 files changed

+309
-31
lines changed

15 files changed

+309
-31
lines changed

src/TensorFlowNET.Core/Framework/IndexedSlices.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,26 @@ public class IndexedSlices : CompositeTensor
1111
{
1212
Tensor _values;
1313
public Tensor values => _values;
14+
Tensor _indices;
15+
public Tensor indices => _indices;
16+
Tensor _dense_shape;
17+
public Tensor dense_shape => _dense_shape;
18+
19+
public string name => _values.name;
20+
21+
public string device => _values.Device;
22+
23+
public Operation op => _values.op;
24+
25+
public TF_DataType dtype => _values.dtype;
26+
27+
public Graph graph => _values.graph;
1428

1529
public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null)
1630
{
17-
31+
_values = values;
32+
_indices = indices;
33+
_dense_shape = dense_shape;
1834
}
1935

2036
public static implicit operator Tensor(IndexedSlices indexedSlices)

src/TensorFlowNET.Core/Gradients/array_grad.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,13 @@ there will be a small number of performance regressions.*/
8383
new Tensor[] { non_neg_concat_dim, tf.constant(0) },
8484
new Tensor[] { tf.constant(1), tf.constant(-1) });
8585
var squeeze_sizes = array_ops.squeeze(slice);
86-
out_grads = gen_ops.split(grad, squeeze_sizes, non_neg_concat_dim).ToList();
86+
out_grads = gen_array_ops.split(grad, squeeze_sizes, non_neg_concat_dim).ToList();
8787
}
8888
else
8989
{
90-
var offset = gen_ops.concat_offset(non_neg_concat_dim, sizes);
90+
var offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes);
9191
foreach (var (begin, size) in zip(offset, sizes))
92-
out_grads.Add(gen_ops.slice(grad, begin, size));
92+
out_grads.Add(gen_array_ops.slice(grad, begin, size));
9393
}
9494

9595
return (end_value_index <= dim_index ?
@@ -129,7 +129,7 @@ private static Tensor[] _ExtractInputShapes(Tensor[] inputs)
129129
if (fully_known)
130130
return sizes;
131131
else
132-
return gen_ops.shape_n(inputs);
132+
return gen_array_ops.shape_n(inputs);
133133
}
134134

135135
/// <summary>

src/TensorFlowNET.Core/Gradients/gradients_util.cs

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,7 @@ public static Tensor[] _GradientsHelper(Tensor[] ys,
9393
{
9494
// generate gradient subgraph for op.
9595
var op = queue.Dequeue();
96-
if(op.name == "embedding/ExpandDims")
97-
{
9896

99-
}
10097
_maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops);
10198
//if (loop_state != null)
10299
//loop_state.EnterGradWhileContext(op, before: true);
@@ -311,16 +308,22 @@ private static Tensor[] _AggregatedGrads(Dictionary<string, List<List<Tensor>>>
311308
// Aggregate multiple gradients, and convert [] to None.
312309
if (out_grad.Count > 0)
313310
{
311+
string used = "";
314312
if (out_grad.Count < 2)
315313
{
316-
string used = "nop";
314+
used = "nop";
317315
if (out_grad.Count == 0)
318316
{
319317
throw new ValueError("_AggregatedGrads out_grad.Length == 0");
320318
}
321319

322320
return_grads[i] = out_grad[0];
323321
}
322+
else
323+
{
324+
used = "add_n";
325+
out_grads[i] = new List<Tensor> { _MultiDeviceAddN(out_grad.ToArray(), gradient_uid) };
326+
}
324327
}
325328
else
326329
{
@@ -331,6 +334,38 @@ private static Tensor[] _AggregatedGrads(Dictionary<string, List<List<Tensor>>>
331334
return return_grads;
332335
}
333336

337+
/// <summary>
338+
/// Adds tensors from potentially multiple devices.
339+
/// </summary>
340+
/// <param name="tensor_list"></param>
341+
/// <param name="gradient_uid"></param>
342+
/// <returns></returns>
343+
private static Tensor _MultiDeviceAddN(Tensor[] tensor_list, string gradient_uid)
344+
{
345+
// Basic function structure comes from control_flow_ops.group().
346+
// Sort tensors according to their devices.
347+
var tensors_on_device = new Dictionary<string, List<Tensor>>();
348+
349+
foreach (var tensor in tensor_list)
350+
{
351+
if (!tensors_on_device.ContainsKey(tensor.Device))
352+
tensors_on_device[tensor.Device] = new List<Tensor>();
353+
354+
tensors_on_device[tensor.Device].Add(tensor);
355+
}
356+
357+
// For each device, add the tensors on that device first.
358+
var summands = new List<Tensor>();
359+
foreach(var dev in tensors_on_device.Keys)
360+
{
361+
var tensors = tensors_on_device[dev];
362+
ops._colocate_with_for_gradient(tensors[0].op, gradient_uid, ignore_existing: true);
363+
summands.Add(math_ops.add_n(tensors.ToArray()));
364+
}
365+
366+
return math_ops.add_n(summands.ToArray());
367+
}
368+
334369
/// <summary>
335370
/// The set of ops that terminate the gradient computation.
336371
/// </summary>

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,9 @@ public static Tensor one_hot(Tensor indices, int depth,
276276
});
277277
}
278278

279+
public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string name = null)
280+
=> gen_array_ops.unique(x, out_idx: out_idx, name: name);
281+
279282
public static Tensor where(Tensor condition, object x = null, object y = null, string name = null)
280283
{
281284
if( x == null && y == null)

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ public static Tensor concat_v2(Tensor[] values, int axis, string name = null)
2626
return _op.outputs[0];
2727
}
2828

29+
public static Tensor[] concat_offset(Tensor concat_dim, Tensor[] shape, string name = null)
30+
{
31+
var _op = _op_def_lib._apply_op_helper("ConcatOffset", name: name, args: new { concat_dim, shape });
32+
33+
return _op.outputs;
34+
}
35+
2936
/// <summary>
3037
/// Returns a diagonal tensor with a given diagonal values.
3138
/// </summary>
@@ -205,6 +212,21 @@ public static Tensor reshape(Tensor tensor, int[] shape, string name = null)
205212
return _op.outputs[0];
206213
}
207214

215+
/// <summary>
216+
/// Finds unique elements in a 1-D tensor.
217+
/// </summary>
218+
/// <param name="x"></param>
219+
/// <param name="out_idx"></param>
220+
/// <param name="name"></param>
221+
/// <returns></returns>
222+
public static (Tensor, Tensor) unique(Tensor x, TF_DataType out_idx = TF_DataType.TF_INT32, string name = null)
223+
{
224+
var _op = _op_def_lib._apply_op_helper("Unique", name, new { x, out_idx });
225+
// TODO
226+
throw new NotImplementedException("_result = _UniqueOutput._make(_result)");
227+
// return _op.outputs[0];
228+
}
229+
208230
public static Tensor where()
209231
{
210232
throw new NotImplementedException("where");
@@ -271,6 +293,26 @@ public static Tensor size(Tensor input, TF_DataType out_type = TF_DataType.TF_IN
271293
return _op.outputs[0];
272294
}
273295

296+
/// <summary>
297+
/// Returns a slice from 'input'.
298+
/// </summary>
299+
/// <param name="input"></param>
300+
/// <param name="begin"></param>
301+
/// <param name="size"></param>
302+
/// <param name="name"></param>
303+
/// <returns></returns>
304+
public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name = null)
305+
{
306+
var _op = _op_def_lib._apply_op_helper("Slice", name, new { input, begin, size });
307+
return _op.outputs[0];
308+
}
309+
310+
public static Tensor[] split(Tensor axis, Tensor value, int num_split, string name = null)
311+
{
312+
var _op = _op_def_lib._apply_op_helper("Split", name, new { split_dim = axis, value, num_split });
313+
return _op.outputs;
314+
}
315+
274316
public static Tensor tile(Tensor input, Tensor multiples, string name = null)
275317
{
276318
var _op = _op_def_lib._apply_op_helper("Tile", name, new { input, multiples });

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,19 @@ public static Tensor _all(Tensor input, Tensor axis, bool keep_dims = false, str
1616
return _op.outputs[0];
1717
}
1818

19+
/// <summary>
20+
/// Adds all input tensors element-wise.
21+
/// </summary>
22+
/// <param name="inputs"></param>
23+
/// <param name="name"></param>
24+
/// <returns></returns>
25+
public static Tensor add_n(Tensor[] inputs, string name = null)
26+
{
27+
var _op = _op_def_lib._apply_op_helper("AddN", name, args: new { inputs });
28+
29+
return _op.outputs[0];
30+
}
31+
1932
/// <summary>
2033
/// Returns the index with the largest value across dimensions of a tensor.
2134
/// </summary>
@@ -198,6 +211,20 @@ public static Tensor cosh(Tensor x, string name = null)
198211
return _op.outputs[0];
199212
}
200213

214+
/// <summary>
215+
/// Computes the sum along segments of a tensor.
216+
/// </summary>
217+
/// <param name="data"></param>
218+
/// <param name="segment_ids"></param>
219+
/// <param name="num_segments"></param>
220+
/// <param name="name"></param>
221+
/// <returns></returns>
222+
public static Tensor unsorted_segment_sum(Tensor data, Tensor segment_ids, Tensor num_segments, string name = null)
223+
{
224+
var _op = _op_def_lib._apply_op_helper("UnsortedSegmentSum", name, new { data, segment_ids, num_segments });
225+
return _op.outputs[0];
226+
}
227+
201228
public static Tensor tan(Tensor x, string name = null)
202229
{
203230
var _op = _op_def_lib._apply_op_helper("Tan", name, args: new { x });

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ public static Tensor add_n(Tensor[] inputs, string name = null)
4444
return array_ops.identity(values, name: name);
4545
return values;
4646
}
47-
throw new NotImplementedException("math_ops add_n n > 1");
48-
// return gen_math_ops.add_n(inputs, name: name);
47+
48+
return gen_math_ops.add_n(inputs, name: name);
4949
}
5050

5151
public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
@@ -126,6 +126,9 @@ public static Tensor div_no_nan(Tensor x, Tensor y, string name = null)
126126
public static Tensor equal<Tx, Ty>(Tx x, Ty y, string name = null)
127127
=> gen_math_ops.equal(x, y, name: name);
128128

129+
public static Tensor sqrt(Tensor x, string name = null)
130+
=> gen_math_ops.sqrt(x, name: name);
131+
129132
public static Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null)
130133
=> gen_math_ops.mul(x, y, name: name);
131134

@@ -319,6 +322,17 @@ public static Tensor reduce_min(Tensor input_tensor, int[] axis = null, bool kee
319322
return _may_reduce_to_scalar(keepdims, axis, min);
320323
}
321324

325+
/// <summary>
326+
/// Computes the sum along segments of a tensor.
327+
/// </summary>
328+
/// <param name="data"></param>
329+
/// <param name="segment_ids"></param>
330+
/// <param name="num_segments"></param>
331+
/// <param name="name"></param>
332+
/// <returns></returns>
333+
public static Tensor unsorted_segment_sum(Tensor data, Tensor segment_ids, Tensor num_segments, string name = null)
334+
=> gen_math_ops.unsorted_segment_sum(data, segment_ids, num_segments, name: name);
335+
322336
/// <summary>
323337
/// Casts a tensor to type `int32`.
324338
/// </summary>

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
<AssemblyName>TensorFlow.NET</AssemblyName>
66
<RootNamespace>Tensorflow</RootNamespace>
77
<TargetTensorFlow>1.14.0</TargetTensorFlow>
8-
<Version>0.8.1</Version>
8+
<Version>0.8.2</Version>
99
<Authors>Haiping Chen</Authors>
1010
<Company>SciSharp STACK</Company>
11-
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
11+
<GeneratePackageOnBuild>false</GeneratePackageOnBuild>
1212
<Copyright>Apache 2.0</Copyright>
1313
<RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl>
1414
<RepositoryType>git</RepositoryType>
@@ -17,14 +17,15 @@
1717
<PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags>
1818
<Description>Google's TensorFlow full binding in .NET Standard.
1919
Docs: https://tensorflownet.readthedocs.io</Description>
20-
<AssemblyVersion>0.8.1.0</AssemblyVersion>
20+
<AssemblyVersion>0.8.2.0</AssemblyVersion>
2121
<PackageReleaseNotes>Changes since v0.8:
2222

2323
1. Remove global static graph instance.
2424
2. Provide custom gradient function.
25-
3. Add gradient function for Conv2D.</PackageReleaseNotes>
25+
3. Add gradient function for Conv2D.
26+
4. Fix bug for Transfer Learning example.</PackageReleaseNotes>
2627
<LangVersion>7.2</LangVersion>
27-
<FileVersion>0.8.1.0</FileVersion>
28+
<FileVersion>0.8.2.0</FileVersion>
2829
</PropertyGroup>
2930

3031
<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
@@ -42,6 +43,10 @@ Docs: https://tensorflownet.readthedocs.io</Description>
4243
<None Remove="runtimes\**" />
4344
</ItemGroup>
4445

46+
<ItemGroup>
47+
<Compile Remove="Operations\gen_ops.cs" />
48+
</ItemGroup>
49+
4550
<ItemGroup>
4651
<None Remove="Protobuf\README.md" />
4752
</ItemGroup>
Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using Tensorflow.Framework;
5+
using static Tensorflow.Python;
46

57
namespace Tensorflow.Train
68
{
@@ -10,9 +12,10 @@ namespace Tensorflow.Train
1012
/// </summary>
1113
public class AdamOptimizer : Optimizer
1214
{
13-
private float _beta1;
14-
private float _beta2;
15-
private float _epsilon;
15+
float _beta1;
16+
float _beta2;
17+
float _epsilon;
18+
Tensor _lr_t, _beta1_t, _beta2_t, _epsilon_t;
1619

1720
public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1e-8f, bool use_locking = false, string name = "Adam")
1821
: base(learning_rate, use_locking, name)
@@ -21,5 +24,51 @@ public AdamOptimizer(float learning_rate, float beta1 = 0.9f, float beta2 = 0.99
2124
_beta2 = beta2;
2225
_epsilon = epsilon;
2326
}
27+
28+
public override Operation _apply_sparse(IndexedSlices grad, RefVariable var)
29+
{
30+
return _apply_sparse_shared(grad.values, var, grad.indices, (x, i, v) =>
31+
{
32+
return state_ops.scatter_add(x, i, v, use_locking: _use_locking);
33+
});
34+
}
35+
36+
private Operation _apply_sparse_shared(Tensor grad, RefVariable var, Tensor indices, Func<RefVariable, Tensor, Tensor, Tensor> scatter_add)
37+
{
38+
var (beta1_power_v, beta2_power_v) = _get_beta_accumulators();
39+
Tensor beta1_power = math_ops.cast(beta1_power_v, var.dtype.as_base_dtype());
40+
Tensor beta2_power = math_ops.cast(beta2_power_v, var.dtype.as_base_dtype());
41+
var lr_t = math_ops.cast(_lr_t, var.dtype.as_base_dtype());
42+
var beta1_t = math_ops.cast(_beta1_t, var.dtype.as_base_dtype());
43+
var beta2_t = math_ops.cast(_beta2_t, var.dtype.as_base_dtype());
44+
var epsilon_t = math_ops.cast(_epsilon_t, var.dtype.as_base_dtype());
45+
var lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power));
46+
var m = get_slot(var, "m");
47+
var m_scaled_g_values = grad * (1 - beta1_t);
48+
var m_t = state_ops.assign(m, m * beta1_t, use_locking: _use_locking);
49+
with(ops.control_dependencies(new[] { m_t }), delegate
50+
{
51+
m_t = scatter_add(m, indices, m_scaled_g_values);
52+
});
53+
54+
var v = get_slot(var, "v");
55+
var v_scaled_g_values = (grad * grad) * (1 - beta2_t);
56+
var v_t = state_ops.assign(v, v * beta2_t, use_locking: _use_locking);
57+
with(ops.control_dependencies(new[] { v_t }), delegate
58+
{
59+
v_t = scatter_add(v, indices, v_scaled_g_values);
60+
});
61+
var v_sqrt = math_ops.sqrt(v_t);
62+
var var_update = state_ops.assign_sub(var, lr * m_t / (v_sqrt + epsilon_t), use_locking: _use_locking);
63+
return control_flow_ops.group(new[] { var_update, m_t, v_t });
64+
}
65+
66+
private (RefVariable, RefVariable) _get_beta_accumulators()
67+
{
68+
ops.init_scope();
69+
var graph = ops.get_default_graph();
70+
return (_get_non_slot_variable("beta1_power", graph: graph),
71+
_get_non_slot_variable("beta2_power", graph: graph));
72+
}
2473
}
2574
}

0 commit comments

Comments
 (0)