Skip to content

Commit e02d6d0

Browse files
committed
ongoing tf.keras.backend.cs
1 parent f13e35d commit e02d6d0

File tree

7 files changed

+255
-10
lines changed

7 files changed

+255
-10
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
using NumSharp;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
using Tensorflow;
6+
7+
namespace Tensorflow.Keras
8+
{
9+
public abstract class BackendBase
10+
{
11+
TF_DataType _FLOATX = dtypes.float32;
12+
float _EPSILON = 1e-7f;
13+
ImageDataFormat _IMAGE_DATA_FORMAT = ImageDataFormat.channels_last;
14+
15+
16+
public float epsilon() => _EPSILON;
17+
18+
public void set_epsilon(float e) => _EPSILON = e;
19+
20+
public TF_DataType floatx() => _FLOATX;
21+
22+
public void set_floatx(TF_DataType floatx) => _FLOATX = floatx;
23+
24+
public NDArray cast_to_floatx(NDArray x) => np.array(x, dtype: _FLOATX.as_numpy_datatype());
25+
26+
public ImageDataFormat image_data_format() => _IMAGE_DATA_FORMAT;
27+
28+
public void set_image_data_format(ImageDataFormat data_format) => _IMAGE_DATA_FORMAT = data_format;
29+
30+
public ImageDataFormat normalize_data_format(object value = null)
31+
{
32+
if (value == null)
33+
value = _IMAGE_DATA_FORMAT;
34+
if (value.GetType() == typeof(ImageDataFormat))
35+
return (ImageDataFormat)value;
36+
else if (value.GetType() == typeof(string))
37+
{
38+
ImageDataFormat dataFormat;
39+
if(Enum.TryParse((string)value, true, out dataFormat))
40+
{
41+
if (Enum.IsDefined(typeof(ImageDataFormat), dataFormat) | dataFormat.ToString().Contains(","))
42+
return dataFormat;
43+
}
44+
}
45+
throw new Exception("The `data_format` argument must be one of \"channels_first\", \"channels_last\". Received: " + value.ToString());
46+
}
47+
48+
//Legacy Methods
49+
50+
public void set_image_dim_ordering(ImageDimOrder dim_ordering)
51+
{
52+
if (dim_ordering == ImageDimOrder.th)
53+
_IMAGE_DATA_FORMAT = ImageDataFormat.channels_first;
54+
else if (dim_ordering == ImageDimOrder.tf)
55+
_IMAGE_DATA_FORMAT = ImageDataFormat.channels_last;
56+
else
57+
throw new Exception("Unknown dim_ordering:"+ dim_ordering);
58+
}
59+
60+
public ImageDimOrder image_dim_ordering()
61+
{
62+
if (_IMAGE_DATA_FORMAT == ImageDataFormat.channels_first)
63+
return ImageDimOrder.th;
64+
else
65+
return ImageDimOrder.tf;
66+
}
67+
}
68+
public enum ImageDimOrder
69+
{
70+
tf,
71+
th
72+
}
73+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
namespace Tensorflow.Keras
2+
{
3+
public enum GraphLearningPhase
4+
{
5+
train_mode = 1,
6+
test_mode = 0
7+
}
8+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
namespace Tensorflow.Keras
2+
{
3+
public enum ImageDataFormat
4+
{
5+
channels_last,
6+
channels_first
7+
}
8+
}

src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,14 +95,14 @@ public static string unique_layer_name(string name, Dictionary<(string, string),
9595
{
9696
var graph = ops.get_default_graph();
9797
Dictionary<(string, string), int> name_uid_map = null;
98-
if (backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph.graph_key))
98+
if (backend.PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph))
9999
{
100-
name_uid_map = backend.PER_GRAPH_LAYER_NAME_UIDS[graph.graph_key];
100+
name_uid_map = backend.PER_GRAPH_LAYER_NAME_UIDS[graph];
101101
}
102102
else
103103
{
104104
name_uid_map = new Dictionary<(string, string), int>();
105-
backend.PER_GRAPH_LAYER_NAME_UIDS[graph.graph_key] = name_uid_map;
105+
backend.PER_GRAPH_LAYER_NAME_UIDS[graph] = name_uid_map;
106106
}
107107

108108
return name_uid_map;
Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,51 @@
11
using System;
22
using System.Collections.Generic;
33
using System.Text;
4+
using System.Runtime.CompilerServices;
5+
using static Tensorflow.Python;
46

57
namespace Tensorflow.Keras
68
{
7-
public class backend
9+
public class backend : BackendBase
810
{
11+
/* ---------------------------------------- KERAS BACKEND NATIVE OBJECTS ---------------------------------------- */
12+
public static Func<Array, double> py_sum = sum;
13+
public static Func<Array, bool> py_all = all;
14+
//Func<Array, bool> py_any = any;
15+
//Func<double, double, double, IEnumerable<double>> py_slice = slice;
16+
17+
public static Session _SESSION = Tensorflow.tf.defaultSession;
18+
public static Graph _GRAPH = null;
19+
public static Dictionary<Graph, GraphLearningPhase> _GRAPH_LEARNING_PHASES;
20+
//Dictionary<Graph, Dictionary<string, int>> PER_GRAPH_LAYER_NAME_UIDS;
21+
public static bool _MANUAL_VAR_INIT = false;
22+
public static List<string> _LOCAL_DEVICES = null;
23+
/* -------------------------------------- KERAS BACKEND NATIVE OBJECTS END -------------------------------------- */
24+
925
/// <summary>
1026
/// A global dictionary mapping graph objects to an index of counters used
1127
/// for various layer names in each graph.
1228
/// Allows to give unique autogenerated names to layers, in a graph-specific way.
1329
/// </summary>
14-
public static Dictionary<string, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<string, Dictionary<(string, string), int>>();
30+
public static Dictionary<Graph, Dictionary<(string, string), int>> PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>();
1531
public static Dictionary<string, RefVariable> _GRAPH_VARIABLES = new Dictionary<string, RefVariable>();
32+
public static Dictionary<string, Optimizer> _GRAPH_TF_OPTIMIZERS = new Dictionary<string, Optimizer>();
33+
34+
public static _DummyEagerGraph _DUMMY_EAGER_GRAPH = new _DummyEagerGraph();
35+
1636
public static void track_variable(RefVariable v)
1737
{
1838
var graph = v.graph;
1939
_GRAPH_VARIABLES[graph.graph_key] = v;
2040
}
2141

22-
public static Tensor placeholder(int[] shape = null,
23-
int ndim = -1,
24-
TF_DataType dtype = TF_DataType.DtInvalid,
25-
bool sparse = false,
42+
public static Tensor placeholder(int[] shape = null,
43+
int ndim = -1,
44+
TF_DataType dtype = TF_DataType.DtInvalid,
45+
bool sparse = false,
2646
string name = null)
2747
{
28-
if(sparse)
48+
if (sparse)
2949
{
3050
throw new NotImplementedException("placeholder sparse is true");
3151
}
@@ -39,5 +59,56 @@ public static Graph get_graph()
3959
{
4060
return ops.get_default_graph();
4161
}
62+
63+
public static int get_uid(string prefix, string @namespace = "")
64+
{
65+
var graph = tf.get_default_graph();
66+
if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph))
67+
PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>());
68+
PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)] += 1;
69+
70+
return PER_GRAPH_LAYER_NAME_UIDS[graph][(@namespace, prefix)];
71+
}
72+
public static int get_uid((string, string) name)
73+
{
74+
var graph = tf.get_default_graph();
75+
if (!PER_GRAPH_LAYER_NAME_UIDS.ContainsKey(graph))
76+
PER_GRAPH_LAYER_NAME_UIDS.Add(graph, new defaultdict<(string, string), int>());
77+
PER_GRAPH_LAYER_NAME_UIDS[graph][(name)] += 1;
78+
79+
return PER_GRAPH_LAYER_NAME_UIDS[graph][name];
80+
}
81+
public static void reset_uids() => PER_GRAPH_LAYER_NAME_UIDS = new Dictionary<Graph, Dictionary<(string, string), int>>();
82+
public static void clear_session()
83+
{
84+
ops.reset_default_graph();
85+
reset_uids();
86+
_SESSION = null;
87+
var phase = tf.placeholder_with_default(false, new int[] { }, name: "keras_learning_phase");
88+
_GRAPH_LEARNING_PHASES = new Dictionary<Graph, GraphLearningPhase>();
89+
_GRAPH_LEARNING_PHASES[tf.get_default_graph()] = 0;
90+
}
91+
public static void manual_variable_initialization(bool value)
92+
{
93+
_MANUAL_VAR_INIT = value;
94+
}
95+
public static GraphLearningPhase learning_phase()
96+
{
97+
var graph = tf.get_default_graph();
98+
if (_GRAPH_LEARNING_PHASES.ContainsKey(graph))
99+
{
100+
var phase = tf.placeholder_with_default(false, shape: new int[] { }, name: "keras_learning_phase");
101+
_GRAPH_LEARNING_PHASES[graph] = 0;
102+
}
103+
return _GRAPH_LEARNING_PHASES[graph];
104+
}
105+
public static void set_learning_phase(bool value)
106+
{
107+
_GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0);
108+
}
109+
110+
111+
public class _DummyEagerGraph
112+
{ }
42113
}
43114
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
using System.Collections.Generic;
2+
3+
namespace System.Collections.Generic
4+
{
5+
public class defaultdict<TKey, TValue> : Dictionary<TKey, TValue> where TValue : new()
6+
{
7+
public new TValue this[TKey key]
8+
{
9+
get
10+
{
11+
TValue val;
12+
if(!TryGetValue(key, out val))
13+
{
14+
val = default(TValue);
15+
Add(key, val);
16+
}
17+
return val;
18+
}
19+
set { base[key] = value; }
20+
}
21+
}
22+
}

src/TensorFlowNET.Core/Python.cs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,69 @@ public static Dictionary<string, object> ConvertToDict(object dyn)
184184
return dictionary;
185185
}
186186

187+
188+
public static bool all(IEnumerable enumerable)
189+
{
190+
foreach (var e1 in enumerable)
191+
{
192+
if (!Convert.ToBoolean(e1))
193+
return false;
194+
}
195+
return true;
196+
}
197+
198+
public static bool any(IEnumerable enumerable)
199+
{
200+
foreach (var e1 in enumerable)
201+
{
202+
if (Convert.ToBoolean(e1))
203+
return true;
204+
}
205+
return false;
206+
}
207+
208+
public static double sum(IEnumerable enumerable)
209+
{
210+
var typedef = new Type[] { typeof(double), typeof(int), typeof(float) };
211+
var sum = 0.0d;
212+
foreach (var e1 in enumerable)
213+
{
214+
if (!typedef.Contains(e1.GetType()))
215+
throw new Exception("Numeric array expected");
216+
sum += (double)e1;
217+
}
218+
return sum;
219+
}
220+
221+
public static double sum<TKey, TValue>(Dictionary<TKey, TValue> values)
222+
{
223+
return sum(values.Keys);
224+
}
225+
226+
public static IEnumerable<double> slice(double start, double end, double step = 1)
227+
{
228+
for (double i = start; i < end; i += step)
229+
yield return i;
230+
}
231+
232+
public static IEnumerable<float> slice(float start, float end, float step = 1)
233+
{
234+
for (float i = start; i < end; i += step)
235+
yield return i;
236+
}
237+
238+
public static IEnumerable<int> slice(int start, int end, int step = 1)
239+
{
240+
for (int i = start; i < end; i += step)
241+
yield return i;
242+
}
243+
244+
public static IEnumerable<int> slice(int range)
245+
{
246+
for (int i = 0; i < range; i++)
247+
yield return i;
248+
}
249+
187250
public static bool hasattr(object obj, string key)
188251
{
189252
var __type__ = (obj).GetType();

0 commit comments

Comments
 (0)