Skip to content

Commit 5a73e69

Browse files
committed
add Tensor[] pattern match for ops.name_scope.
1 parent e56e5d3 commit 5a73e69

File tree

4 files changed

+10
-35
lines changed

4 files changed

+10
-35
lines changed

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public Tensor __call__(Tensor inputs,
3737
VariableScope scope = null)
3838
{
3939
_set_scope(scope);
40-
_graph = ops._get_graph_from_inputs(new List<Tensor> { inputs }, graph: _graph);
40+
_graph = ops._get_graph_from_inputs(new Tensor[] { inputs }, graph: _graph);
4141

4242
variable_scope scope_context_manager = null;
4343
if (built)

src/TensorFlowNET.Core/ops.name_scope.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34
using System.Text;
45
using Tensorflow.Eager;
56

@@ -37,8 +38,11 @@ public void __enter__()
3738
_name = _name == null ? _default_name : _name;
3839

3940
Graph g = null;
40-
if (_values is List<Tensor> values)
41-
g = _get_graph_from_inputs(values);
41+
42+
if (_values is List<Tensor> vList)
43+
g = _get_graph_from_inputs(vList.ToArray());
44+
else if (_values is Tensor[] vArray)
45+
g = _get_graph_from_inputs(vArray);
4246

4347
if (g == null)
4448
g = get_default_graph();

src/TensorFlowNET.Core/ops.py.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,10 @@ public static void reset_default_graph()
102102
default_graph = tf.Graph();
103103
}
104104

105+
public static Graph _get_graph_from_inputs(params Tensor[] op_input_list)
106+
=> _get_graph_from_inputs(op_input_list: op_input_list);
105107

106-
public static Graph _get_graph_from_inputs(List<Tensor> op_input_list, Graph graph = null)
108+
public static Graph _get_graph_from_inputs(Tensor[] op_input_list, Graph graph = null)
107109
{
108110
foreach(var op_input in op_input_list)
109111
{

test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -203,37 +203,6 @@ protected virtual bool RunWithBuiltGraph(Session session, Graph graph)
203203
return (train_x, valid_x, train_y, valid_y);
204204
}
205205

206-
//private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f)
207-
//{
208-
// Console.WriteLine("Splitting in Training and Testing data...");
209-
// var stopwatch = Stopwatch.StartNew();
210-
// int len = x.Length;
211-
// int train_size = int.Parse((len * (1 - test_size)).ToString());
212-
// var random = new Random(17);
213-
214-
// // we collect indices of labels
215-
// var labels = new Dictionary<int, HashSet<int>>();
216-
// var shuffled_indices = random.Shuffle<int>(range(len).ToArray());
217-
// foreach (var i in shuffled_indices)
218-
// {
219-
// var label = y[i];
220-
// if (!labels.ContainsKey(i))
221-
// labels[label] = new HashSet<int>();
222-
// labels[label].Add(i);
223-
// }
224-
225-
// var train_x = new int[train_size][];
226-
// var valid_x = new int[len - train_size][];
227-
// var train_y = new int[train_size];
228-
// var valid_y = new int[len - train_size];
229-
230-
// FillWithShuffledLabels(x, y, train_x, train_y, random, labels);
231-
// FillWithShuffledLabels(x, y, valid_x, valid_y, random, labels);
232-
233-
// Console.WriteLine("\tDONE " + stopwatch.Elapsed);
234-
// return (train_x, valid_x, train_y, valid_y);
235-
//}
236-
237206
private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
238207
{
239208
int i = 0;

0 commit comments

Comments
 (0)