
Commit f13e35d

extend gradient function capability.

1 parent 4ff993b commit f13e35d

File tree

8 files changed: +107 -64 lines changed


docs/source/Gradient.md

Lines changed: 13 additions & 0 deletions

@@ -1,2 +1,15 @@
 # Chapter. Gradient
 
+### Register custom gradient function
+
+TF.NET is extensible: a custom gradient function can be registered for any operation type.
+
+```csharp
+// register a custom gradient function for the ConcatV2 op type
+ops.RegisterGradientFunction("ConcatV2", (oper, out_grads) =>
+{
+    var grad = out_grads[0];
+    return new Tensor[] { };
+});
+```
+
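The snippet above returns an empty array, so nothing would actually flow backward. A fuller sketch of the delegate contract, using the same `ops.RegisterGradientFunction` API from this commit ("MyIdentity" is a hypothetical op type used only for illustration):

```csharp
// The delegate receives the forward Operation and the gradients flowing into its
// outputs, and must return one gradient per op input (null for inputs that get none).
// "MyIdentity" is a hypothetical op type; identity passes the gradient straight through.
ops.RegisterGradientFunction("MyIdentity", (oper, out_grads) =>
{
    return new Tensor[] { out_grads[0] };
});
```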
src/TensorFlowNET.Core/Gradients/RegisterGradient.cs

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Gradients
+{
+    public class RegisterGradient : Attribute
+    {
+        public string Name { get; set; }
+
+        public RegisterGradient(string name)
+        {
+            Name = name;
+        }
+    }
+}
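Declared as a plain attribute, `RegisterGradient` does nothing by itself; `ops.get_gradient_function` (changed later in this commit) scans the executing assembly for it. A hedged sketch of how a gradient class is tagged, modeled on `array_grad` below ("MyGrads" and "MyOp" are hypothetical names):

```csharp
using Tensorflow;
using Tensorflow.Gradients;

// The class-level attribute marks the type for the reflection scan; each
// method-level attribute names the op type the method differentiates.
// Note the scan uses Assembly.GetExecutingAssembly(), so classes defined
// outside TensorFlowNET.Core must use ops.RegisterGradientFunction instead.
[RegisterGradient("MyGrads")]
public class MyGrads
{
    [RegisterGradient("MyOp")]
    public static Tensor[] _MyOpGrad(Operation op, Tensor[] grads)
    {
        // pass-through gradient: one entry per op input
        return new Tensor[] { grads[0] };
    }
}
```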

src/TensorFlowNET.Core/Gradients/array_grad.cs

Lines changed: 5 additions & 1 deletion

@@ -10,8 +10,10 @@ namespace Tensorflow.Gradients
     /// <summary>
     /// tensorflow\python\ops\array_grad.py
     /// </summary>
+    [RegisterGradient("array_grad")]
    public class array_grad
    {
+        [RegisterGradient("ConcatV2")]
        public static Tensor[] _ConcatGradV2(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -123,12 +125,13 @@ private static Tensor[] _ExtractInputShapes(Tensor[] inputs)
            return gen_ops.shape_n(inputs);
        }
 
-
+        [RegisterGradient("Reshape")]
        public static Tensor[] _ReshapeGrad(Operation op, Tensor[] grads)
        {
            return new Tensor[] { array_ops.reshape(grads[0], array_ops.shape(op.inputs[0])), null };
        }
 
+        [RegisterGradient("Squeeze")]
        public static Tensor[] _SqueezeGrad(Operation op, Tensor[] grads)
        {
            return new Tensor[] { _ReshapeToInput(op, grads[0]) };
@@ -139,6 +142,7 @@ private static Tensor _ReshapeToInput(Operation op, Tensor grad)
            return array_ops.reshape(grad, array_ops.shape(op.inputs[0]));
        }
 
+        [RegisterGradient("Transpose")]
        public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads)
        {
            var p = op.inputs[1];
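The hunk above truncates `_TransposeGrad` at its first line. For reference, the upstream Python version (`tensorflow/python/ops/array_grad.py`) transposes the incoming gradient by the inverted permutation; a hedged C# sketch of that rule, where `array_ops.transpose` and `gen_array_ops.invert_permutation` are assumed names that may differ in this version of TF.NET:

```csharp
[RegisterGradient("Transpose")]
public static Tensor[] _TransposeGrad(Operation op, Tensor[] grads)
{
    var p = op.inputs[1];                           // permutation used in the forward pass
    var inv = gen_array_ops.invert_permutation(p);  // assumed helper
    // undo the transpose on the incoming gradient; the permutation input gets no gradient
    return new Tensor[] { array_ops.transpose(grads[0], inv), null };
}
```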

src/TensorFlowNET.Core/Gradients/control_flow_grad.py.cs

Lines changed: 3 additions & 2 deletions

@@ -69,11 +69,12 @@ public Tensor[] _SwitchGrad(Tensor op, Tensor[] grads)
            // false_grad = switch(grad[0], op.inputs[1])[0]
            // true_grad = switch(grad[1], op.inputs[1])[1]
            // return merge([false_grad, true_grad])[0], None
-        }
-
+        }
+
        /// <summary>
        /// Gradients for a Merge op are calculated using a Switch op.
        /// </summary>
+        [RegisterGradient("Merge")]
        public static Tensor[] _MergeGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
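`_SwitchGrad` is still a stub whose body carries only the pseudocode comments shown above. A hedged sketch of what those comments describe, assuming `control_flow_ops` exposes `switch`/`merge` helpers mirroring the Python API (the names and the static `Operation`-taking signature the registry expects are assumptions, not this commit's code):

```csharp
public static Tensor[] _SwitchGrad(Operation op, Tensor[] grads)
{
    var pred = op.inputs[1];                                      // switch predicate
    var false_grad = control_flow_ops.@switch(grads[0], pred)[0]; // assumed helper
    var true_grad = control_flow_ops.@switch(grads[1], pred)[1];
    // merge the two branch gradients; the predicate itself gets no gradient
    return new Tensor[] { control_flow_ops.merge(new[] { false_grad, true_grad })[0], null };
}
```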

src/TensorFlowNET.Core/Gradients/math_grad.cs

Lines changed: 16 additions & 0 deletions

@@ -10,8 +10,10 @@ namespace Tensorflow.Gradients
     /// <summary>
     /// Gradients for operators defined in math_ops.py.
     /// </summary>
+    [RegisterGradient("math_grad")]
    public class math_grad
    {
+        [RegisterGradient("Add")]
        public static Tensor[] _AddGrad(Operation op, Tensor[] grads)
        {
            var x = op.inputs[0];
@@ -32,6 +34,7 @@ public static Tensor[] _AddGrad(Operation op, Tensor[] grads)
            return new Tensor[] { r1, r2 };
        }
 
+        [RegisterGradient("DivNoNan")]
        public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -59,6 +62,7 @@ public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads)
        /// <param name="op"></param>
        /// <param name="grads"></param>
        /// <returns></returns>
+        [RegisterGradient("Exp")]
        public static Tensor[] _ExpGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -69,11 +73,13 @@ public static Tensor[] _ExpGrad(Operation op, Tensor[] grads)
            });
        }
 
+        [RegisterGradient("Identity")]
        public static Tensor[] _IdGrad(Operation op, Tensor[] grads)
        {
            return new Tensor[] { grads[0] };
        }
 
+        [RegisterGradient("Log")]
        public static Tensor[] _LogGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -84,6 +90,7 @@ public static Tensor[] _LogGrad(Operation op, Tensor[] grads)
            });
        }
 
+        [RegisterGradient("Mul")]
        public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
        {
            var x = op.inputs[0];
@@ -112,6 +119,7 @@ public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
            return new Tensor[] { reshape1, reshape2 };
        }
 
+        [RegisterGradient("MatMul")]
        public static Tensor[] _MatMulGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -145,6 +153,7 @@ public static Tensor[] _MatMulGrad(Operation op, Tensor[] grads)
            return new Tensor[] { grad_a, grad_b };
        }
 
+        [RegisterGradient("Mean")]
        public static Tensor[] _MeanGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -159,6 +168,7 @@ public static Tensor[] _MeanGrad(Operation op, Tensor[] grads)
            return new Tensor[] { math_ops.truediv(sum_grad, math_ops.cast(factor, sum_grad.dtype)), null };
        }
 
+        [RegisterGradient("Neg")]
        public static Tensor[] _NegGrad(Operation op, Tensor[] grads)
        {
            return new Tensor[] { -grads[0] };
@@ -169,6 +179,7 @@ private static Tensor _safe_shape_div(Tensor x, Tensor y)
            return math_ops.floordiv(x, gen_math_ops.maximum(y, 1));
        }
 
+        [RegisterGradient("Sub")]
        public static Tensor[] _SubGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -198,6 +209,7 @@ public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad
            !x_shape.Contains(-1);
        }
 
+        [RegisterGradient("Sum")]
        public static Tensor[] _SumGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -231,6 +243,7 @@ public static Tensor[] _SumGrad(Operation op, Tensor[] grads)
            return new Tensor[] { gen_array_ops.tile(grad, tile_scaling), null };
        }
 
+        [RegisterGradient("RealDiv")]
        public static Tensor[] _RealDivGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -254,6 +267,7 @@ public static Tensor[] _RealDivGrad(Operation op, Tensor[] grads)
            return new Tensor[] { reshape2, reshape1 };
        }
 
+        [RegisterGradient("Sigmoid")]
        public static Tensor[] _SigmoidGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -266,6 +280,7 @@ public static Tensor[] _SigmoidGrad(Operation op, Tensor[] grads)
            });
        }
 
+        [RegisterGradient("Square")]
        public static Tensor[] _SquareGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -279,6 +294,7 @@ public static Tensor[] _SquareGrad(Operation op, Tensor[] grads)
            });
        }
 
+        [RegisterGradient("Pow")]
        public static Tensor[] _PowGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
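Most of these bodies are elided by the diff, but the visible `_AddGrad` fragment (`x`, `y`, then `return new Tensor[] { r1, r2 };`) follows the standard broadcasting pattern from `tensorflow/python/ops/math_grad.py`: sum the upstream gradient over each input's broadcast axes, then reshape it back. A hedged sketch of that pattern, where `broadcast_gradient_args` and the exact helper names are assumptions:

```csharp
[RegisterGradient("Add")]
public static Tensor[] _AddGrad(Operation op, Tensor[] grads)
{
    var x = op.inputs[0];
    var y = op.inputs[1];
    var grad = grads[0];

    var sx = array_ops.shape(x);
    var sy = array_ops.shape(y);
    // which axes of each input were broadcast in the forward pass (assumed helper)
    var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);

    // d(x+y)/dx = 1 and d(x+y)/dy = 1, so each input receives the upstream
    // gradient summed over its broadcast axes and reshaped to its own shape
    var r1 = gen_array_ops.reshape(math_ops.reduce_sum(grad, rx), sx);
    var r2 = gen_array_ops.reshape(math_ops.reduce_sum(grad, ry), sy);
    return new Tensor[] { r1, r2 };
}
```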

src/TensorFlowNET.Core/Gradients/nn_grad.cs

Lines changed: 7 additions & 0 deletions

@@ -9,6 +9,7 @@ namespace Tensorflow.Gradients
     /// <summary>
     ///
     /// </summary>
+    [RegisterGradient("nn_grad")]
    public class nn_grad
    {
        /// <summary>
@@ -17,6 +18,7 @@ public class nn_grad
        /// <param name="op"></param>
        /// <param name="grad"></param>
        /// <returns></returns>
+        [RegisterGradient("BiasAdd")]
        public static Tensor[] _BiasAddGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
@@ -25,6 +27,7 @@ public static Tensor[] _BiasAddGrad(Operation op, Tensor[] grads)
            return new Tensor[] { grad, bias_add_grad };
        }
 
+        [RegisterGradient("Relu")]
        public static Tensor[] _ReluGrad(Operation op, Tensor[] grads)
        {
            return new Tensor[] { gen_nn_ops.relu_grad(grads[0], op.outputs[0]) };
@@ -36,6 +39,7 @@ public static Tensor[] _ReluGrad(Operation op, Tensor[] grads)
        /// <param name="op"></param>
        /// <param name="grads"></param>
        /// <returns></returns>
+        [RegisterGradient("Softmax")]
        public static Tensor[] _SoftmaxGrad(Operation op, Tensor[] grads)
        {
            var grad_softmax = grads[0];
@@ -54,6 +58,7 @@ public static Tensor[] _SoftmaxGrad(Operation op, Tensor[] grads)
        /// <param name="grad_loss"></param>
        /// <param name="grad_grad"></param>
        /// <returns></returns>
+        [RegisterGradient("SoftmaxCrossEntropyWithLogits")]
        public static Tensor[] _SoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads)
        {
            var grad_loss = grads[0];
@@ -74,6 +79,7 @@ public static Tensor[] _SoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[]
            };
        }
 
+        [RegisterGradient("SparseSoftmaxCrossEntropyWithLogits")]
        public static Tensor[] _SparseSoftmaxCrossEntropyWithLogitsGrad(Operation op, Tensor[] grads)
        {
            var sparse_softmax_grad_without_gradient = array_ops.prevent_gradient(
@@ -111,6 +117,7 @@ private static Tensor _BroadcastMul(Tensor vec, Tensor mat)
        /// <param name="op"></param>
        /// <param name="grads"></param>
        /// <returns></returns>
+        [RegisterGradient("TopK")]
        public static Tensor[] _TopKGrad(Operation op, Tensor[] grads)
        {
            var grad = grads[0];
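One loose end: the switch statement removed in the next file mapped both "TopK" and "TopKV2" to `_TopKGrad`, while the attribute here names only "TopK". Since `RegisterGradient` carries a single name, the second op type can be wired up through the runtime API this commit adds; a sketch:

```csharp
// map the TopKV2 op type to the same gradient method the old switch used
ops.RegisterGradientFunction("TopKV2", (oper, out_grads) =>
    nn_grad._TopKGrad(oper, out_grads));
```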
Lines changed: 45 additions & 60 deletions

@@ -1,80 +1,65 @@
 using System;
 using System.Collections.Generic;
+using System.Linq;
+using System.Reflection;
 using System.Text;
 using Tensorflow.Gradients;
 
 namespace Tensorflow
 {
    public partial class ops
    {
+        static Dictionary<string, Func<Operation, Tensor[], Tensor[]>> gradientFunctions = null;
+
+        /// <summary>
+        /// Register a new gradient function
+        /// </summary>
+        /// <param name="name">operation type</param>
+        /// <param name="func">function delegate</param>
+        public static void RegisterGradientFunction(string name, Func<Operation, Tensor[], Tensor[]> func)
+        {
+            if (gradientFunctions == null)
+                gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();
+
+            gradientFunctions[name] = func;
+        }
+
        public static Func<Operation, Tensor[], Tensor[]> get_gradient_function(Operation op)
        {
            if (op.inputs == null) return null;
 
-            // map tensorflow\python\ops\math_grad.py
-            return (oper, out_grads) =>
+            if (gradientFunctions == null)
            {
-                // Console.WriteLine($"get_gradient_function: {oper.type} '{oper.name}'");
+                gradientFunctions = new Dictionary<string, Func<Operation, Tensor[], Tensor[]>>();
 
-                switch (oper.type)
+                var gradGroups = Assembly.GetExecutingAssembly()
+                    .GetTypes()
+                    .Where(x => x.GetCustomAttribute<RegisterGradient>() != null)
+                    .ToArray();
+
+                foreach (var g in gradGroups)
                {
-                    case "Add":
-                        return math_grad._AddGrad(oper, out_grads);
-                    case "BiasAdd":
-                        return nn_grad._BiasAddGrad(oper, out_grads);
-                    case "ConcatV2":
-                        return array_grad._ConcatGradV2(oper, out_grads);
-                    case "DivNoNan":
-                        return math_grad._DivNoNanGrad(oper, out_grads);
-                    case "Exp":
-                        return math_grad._ExpGrad(oper, out_grads);
-                    case "Identity":
-                        return math_grad._IdGrad(oper, out_grads);
-                    case "Log":
-                        return math_grad._LogGrad(oper, out_grads);
-                    case "MatMul":
-                        return math_grad._MatMulGrad(oper, out_grads);
-                    case "Merge":
-                        return control_flow_grad._MergeGrad(oper, out_grads);
-                    case "Mul":
-                        return math_grad._MulGrad(oper, out_grads);
-                    case "Mean":
-                        return math_grad._MeanGrad(oper, out_grads);
-                    case "Neg":
-                        return math_grad._NegGrad(oper, out_grads);
-                    case "Sum":
-                        return math_grad._SumGrad(oper, out_grads);
-                    case "Sub":
-                        return math_grad._SubGrad(oper, out_grads);
-                    case "Pow":
-                        return math_grad._PowGrad(oper, out_grads);
-                    case "RealDiv":
-                        return math_grad._RealDivGrad(oper, out_grads);
-                    case "Reshape":
-                        return array_grad._ReshapeGrad(oper, out_grads);
-                    case "Relu":
-                        return nn_grad._ReluGrad(oper, out_grads);
-                    case "Sigmoid":
-                        return math_grad._SigmoidGrad(oper, out_grads);
-                    case "Square":
-                        return math_grad._SquareGrad(oper, out_grads);
-                    case "Squeeze":
-                        return array_grad._SqueezeGrad(oper, out_grads);
-                    case "Softmax":
-                        return nn_grad._SoftmaxGrad(oper, out_grads);
-                    case "SoftmaxCrossEntropyWithLogits":
-                        return nn_grad._SoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
-                    case "SparseSoftmaxCrossEntropyWithLogits":
-                        return nn_grad._SparseSoftmaxCrossEntropyWithLogitsGrad(oper, out_grads);
-                    case "Transpose":
-                        return array_grad._TransposeGrad(oper, out_grads);
-                    case "TopK":
-                    case "TopKV2":
-                        return nn_grad._TopKGrad(oper, out_grads);
-                    default:
-                        throw new NotImplementedException($"get_gradient_function {oper.type}");
+                    var methods = g.GetMethods().Where(x => x.GetCustomAttribute<RegisterGradient>() != null)
+                        .ToArray();
+
+                    foreach (var m in methods)
+                    {
+                        RegisterGradientFunction(m.GetCustomAttribute<RegisterGradient>().Name,
+                            (oper, out_grads) =>
+                                g.InvokeMember(m.Name,
+                                    BindingFlags.InvokeMethod,
+                                    null,
+                                    null,
+                                    args: new object[] { oper, out_grads }) as Tensor[]
+                        );
+                    }
                }
-            };
+            }
+
+            if (!gradientFunctions.ContainsKey(op.type))
+                throw new NotImplementedException($"can't get gradient function through get_gradient_function {op.type}");
+
+            return gradientFunctions[op.type];
        }
    }
 }
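Because `RegisterGradientFunction` assigns through the dictionary indexer, a later registration silently replaces a reflected one, which is how user overrides take effect. One caveat visible in the code above: the attribute scan runs only when `gradientFunctions` is null, so calling `RegisterGradientFunction` before the first `get_gradient_function` lookup pre-creates the dictionary and the built-in scan never runs. A hedged usage sketch that overrides a built-in after the scan has populated the table:

```csharp
// Override the stock Relu gradient after the reflection scan has run.
// Here the override simply delegates to the built-in implementation;
// a real override would transform out_grads before returning.
ops.RegisterGradientFunction("Relu", (oper, out_grads) =>
    nn_grad._ReluGrad(oper, out_grads));
```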

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 2 additions & 1 deletion

@@ -20,7 +20,8 @@ Docs: https://tensorflownet.readthedocs.io</Description>
     <AssemblyVersion>0.8.1.0</AssemblyVersion>
     <PackageReleaseNotes>Changes since v0.8:
 
-Removed global static graph instance.</PackageReleaseNotes>
+1. Removed global static graph instance.
+2. Provide custom gradient function.</PackageReleaseNotes>
     <LangVersion>7.2</LangVersion>
     <FileVersion>0.8.1.0</FileVersion>
   </PropertyGroup>
