Skip to content

Commit d3724a9

Browse files
committed
added make_variable overload
1 parent a4f03c2 commit d3724a9

File tree

3 files changed

+35
-20
lines changed

3 files changed

+35
-20
lines changed

src/TensorFlowNET.Core/Keras/Layers/Layer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ protected virtual RefVariable add_weight(string name,
190190
var variable = _add_variable_with_custom_getter(name,
191191
shape,
192192
dtype: dtype,
193-
getter: getter, // getter == null ? base_layer_utils.make_variable : getter,
193+
getter: (getter == null) ? base_layer_utils.make_variable : getter,
194194
overwrite: true,
195195
initializer: initializer,
196196
trainable: trainable.Value);

src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,21 @@ namespace Tensorflow.Keras.Utils
88
{
99
public class base_layer_utils
1010
{
11+
/// <summary>
12+
/// Adds a new variable to the layer.
13+
/// </summary>
14+
/// <param name="name"></param>
15+
/// <param name="shape"></param>
16+
/// <param name="dtype"></param>
17+
/// <param name="initializer"></param>
18+
/// <param name="trainable"></param>
19+
/// <returns></returns>
20+
public static RefVariable make_variable(string name,
21+
int[] shape,
22+
TF_DataType dtype = TF_DataType.TF_FLOAT,
23+
IInitializer initializer = null,
24+
bool trainable = true) => make_variable(name, shape, dtype, initializer, trainable, true);
25+
1126
/// <summary>
1227
/// Adds a new variable to the layer.
1328
/// </summary>
@@ -28,7 +43,7 @@ public static RefVariable make_variable(string name,
2843

2944
ops.init_scope();
3045

31-
Func<Tensor> init_val = ()=> initializer.call(new TensorShape(shape), dtype: dtype);
46+
Func<Tensor> init_val = () => initializer.call(new TensorShape(shape), dtype: dtype);
3247

3348
var variable_dtype = dtype.as_base_dtype();
3449
var v = tf.Variable(init_val);
@@ -44,21 +59,21 @@ public static RefVariable make_variable(string name,
4459
public static string unique_layer_name(string name, Dictionary<(string, string), int> name_uid_map = null,
4560
string[] avoid_names = null, string @namespace = "", bool zero_based = false)
4661
{
47-
if(name_uid_map == null)
62+
if (name_uid_map == null)
4863
name_uid_map = get_default_graph_uid_map();
4964
if (avoid_names == null)
5065
avoid_names = new string[0];
5166

5267
string proposed_name = null;
53-
while(proposed_name == null || avoid_names.Contains(proposed_name))
68+
while (proposed_name == null || avoid_names.Contains(proposed_name))
5469
{
5570
var name_key = (@namespace, name);
5671
if (!name_uid_map.ContainsKey(name_key))
5772
name_uid_map[name_key] = 0;
5873

5974
if (zero_based)
6075
{
61-
int number = name_uid_map[name_key];
76+
int number = name_uid_map[name_key];
6277
if (number > 0)
6378
proposed_name = $"{name}_{number}";
6479
else

test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,111 +14,111 @@ public class ExamplesTest
1414
public void BasicOperations()
1515
{
1616
tf.Graph().as_default();
17-
new BasicOperations() { Enabled = true }.Train();
17+
new BasicOperations() { Enabled = true }.Run();
1818
}
1919

2020
[TestMethod]
2121
public void HelloWorld()
2222
{
2323
tf.Graph().as_default();
24-
new HelloWorld() { Enabled = true }.Train();
24+
new HelloWorld() { Enabled = true }.Run();
2525
}
2626

2727
[TestMethod]
2828
public void ImageRecognition()
2929
{
3030
tf.Graph().as_default();
31-
new HelloWorld() { Enabled = true }.Train();
31+
new HelloWorld() { Enabled = true }.Run();
3232
}
3333

3434
[Ignore]
3535
[TestMethod]
3636
public void InceptionArchGoogLeNet()
3737
{
3838
tf.Graph().as_default();
39-
new InceptionArchGoogLeNet() { Enabled = true }.Train();
39+
new InceptionArchGoogLeNet() { Enabled = true }.Run();
4040
}
4141

4242
[TestMethod]
4343
public void KMeansClustering()
4444
{
4545
tf.Graph().as_default();
46-
new KMeansClustering() { Enabled = true, IsImportingGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Train();
46+
new KMeansClustering() { Enabled = true, IsImportingGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run();
4747
}
4848

4949
[TestMethod]
5050
public void LinearRegression()
5151
{
5252
tf.Graph().as_default();
53-
new LinearRegression() { Enabled = true }.Train();
53+
new LinearRegression() { Enabled = true }.Run();
5454
}
5555

5656
[TestMethod]
5757
public void LogisticRegression()
5858
{
5959
tf.Graph().as_default();
60-
new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Train();
60+
new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Run();
6161
}
6262

6363
[Ignore]
6464
[TestMethod]
6565
public void NaiveBayesClassifier()
6666
{
6767
tf.Graph().as_default();
68-
new NaiveBayesClassifier() { Enabled = false }.Train();
68+
new NaiveBayesClassifier() { Enabled = false }.Run();
6969
}
7070

7171
[Ignore]
7272
[TestMethod]
7373
public void NamedEntityRecognition()
7474
{
7575
tf.Graph().as_default();
76-
new NamedEntityRecognition() { Enabled = true }.Train();
76+
new NamedEntityRecognition() { Enabled = true }.Run();
7777
}
7878

7979
[TestMethod]
8080
public void NearestNeighbor()
8181
{
8282
tf.Graph().as_default();
83-
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Train();
83+
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run();
8484
}
8585

8686
[Ignore]
8787
[TestMethod]
8888
public void TextClassificationTrain()
8989
{
9090
tf.Graph().as_default();
91-
new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Train();
91+
new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Run();
9292
}
9393

9494
[Ignore]
9595
[TestMethod]
9696
public void TextClassificationWithMovieReviews()
9797
{
9898
tf.Graph().as_default();
99-
new BinaryTextClassification() { Enabled = true }.Train();
99+
new BinaryTextClassification() { Enabled = true }.Run();
100100
}
101101

102102
[TestMethod]
103103
public void NeuralNetXor()
104104
{
105105
tf.Graph().as_default();
106-
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = false }.Train());
106+
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = false }.Run());
107107
}
108108

109109
[TestMethod]
110110
public void NeuralNetXor_ImportedGraph()
111111
{
112112
tf.Graph().as_default();
113-
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = true }.Train());
113+
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = true }.Run());
114114
}
115115

116116

117117
[TestMethod]
118118
public void ObjectDetection()
119119
{
120120
tf.Graph().as_default();
121-
Assert.IsTrue(new ObjectDetection() { Enabled = true, IsImportingGraph = true }.Train());
121+
Assert.IsTrue(new ObjectDetection() { Enabled = true, IsImportingGraph = true }.Run());
122122
}
123123
}
124124
}

0 commit comments

Comments
 (0)