Skip to content

Commit f4067f2

Browse files
committed
tf.train.import_meta_graph can import CondContext.
1 parent 5627443 commit f4067f2

File tree

12 files changed

+1300
-32
lines changed

12 files changed

+1300
-32
lines changed

src/TensorFlowNET.Core/Framework/meta_graph.py.cs

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.IO;
55
using System.Linq;
66
using System.Text;
7+
using Tensorflow.Operations;
78
using static Tensorflow.CollectionDef;
89
using static Tensorflow.MetaGraphDef.Types;
910

@@ -95,15 +96,29 @@ public static (Dictionary<string, RefVariable>, ITensorOrOperation[]) import_sco
9596
}
9697
else
9798
{
98-
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
99+
foreach(var value in col.Value.BytesList.Value)
100+
{
101+
switch (col.Key)
102+
{
103+
case "cond_context":
104+
var proto = CondContextDef.Parser.ParseFrom(value);
105+
var condContext = new CondContext().from_proto(proto, import_scope);
106+
graph.add_to_collection(col.Key, condContext);
107+
break;
108+
default:
109+
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
110+
}
111+
}
99112
}
100113

101114
break;
115+
default:
116+
throw new NotImplementedException("import_scoped_meta_graph_with_return_elements");
102117
}
103118
}
104119

105-
var variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
106-
scope: scope_to_prepend_to_names) as List<RefVariable>;
120+
var variables = graph.get_collection<RefVariable>(ops.GraphKeys.GLOBAL_VARIABLES,
121+
scope: scope_to_prepend_to_names);
107122
var var_list = new Dictionary<string, RefVariable>();
108123
variables.ForEach(v => var_list[ops.strip_name_scope(v.name, scope_to_prepend_to_names)] = v);
109124

src/TensorFlowNET.Core/Graphs/Graph.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,11 @@ public object get_collection(string name, string scope = null)
412412
return _collections.ContainsKey(name) ? _collections[name] : null;
413413
}
414414

415+
public List<T> get_collection<T>(string name, string scope = null)
416+
{
417+
return _collections.ContainsKey(name) ? _collections[name] as List<T> : new List<T>();
418+
}
419+
415420
public object get_collection_ref(string name)
416421
{
417422
if (!_collections.ContainsKey(name))

src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace Tensorflow.Operations
88
/// <summary>
99
/// The context for the conditional construct.
1010
/// </summary>
11-
public class CondContext : ControlFlowContext
11+
public class CondContext : ControlFlowContext, IProtoBuf<CondContextDef, CondContext>
1212
{
1313

1414

@@ -35,16 +35,20 @@ public class CondContext : ControlFlowContext
3535
/// <param name="name">Name of the `CondContext` python object.</param>
3636
/// <param name="context_def"></param>
3737
/// <param name="import_scope"></param>
38-
public CondContext(Tensor pred,
39-
Tensor pivot,
40-
int branch,
38+
public CondContext(Tensor pred = null,
39+
Tensor pivot = null,
40+
int? branch = null,
4141
string name = "cond_text",
42-
object context_def = null,
42+
CondContextDef context_def = null,
4343
string import_scope = null)
4444
{
45+
if (pred == null && context_def == null) return;
46+
4547
_name = ops.get_default_graph().unique_name(name);
46-
if (context_def != null)
47-
throw new NotImplementedException("CondContext context_def is not null");
48+
if (context_def != null)
49+
{
50+
_init_from_proto(context_def, import_scope: import_scope);
51+
}
4852
else
4953
{
5054
// Initializes the default fields.
@@ -61,6 +65,18 @@ public CondContext(Tensor pred,
6165
}
6266
}
6367

68+
private void _init_from_proto(CondContextDef context_def, string import_scope = null)
69+
{
70+
var g = ops.get_default_graph();
71+
_name = ops.prepend_name_scope(context_def.ContextName, import_scope);
72+
var p1 = ops.prepend_name_scope(context_def.PredName, import_scope);
73+
_pred = g.as_graph_element(p1) as Tensor;
74+
var p2 = ops.prepend_name_scope(context_def.PivotName, import_scope);
75+
_pivot = g.as_graph_element(p2) as Tensor;
76+
_branch = context_def.Branch;
77+
__init__(values_def: context_def.ValuesDef, import_scope: import_scope);
78+
}
79+
6480
/// <summary>
6581
/// Add `val` to the current context and its outer context recursively.
6682
/// </summary>
@@ -230,6 +246,22 @@ private Tensor _ProcessOutputTensor(Tensor val)
230246
public override void AddInnerOp(Operation resultOp)
231247
{
232248
throw new NotImplementedException();
233-
}
249+
}
250+
251+
public CondContextDef to_proto(string export_scope)
252+
{
253+
throw new NotImplementedException();
254+
}
255+
256+
public CondContext from_proto(CondContextDef proto, string import_scope)
257+
{
258+
var ret = new CondContext(context_def: proto, import_scope: import_scope);
259+
260+
ret.Enter();
261+
foreach (var nested_def in proto.NestedContexts)
262+
throw new NotImplementedException("");
263+
ret.Exit();
264+
return ret;
265+
}
234266
}
235267
}

src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ public abstract class ControlFlowContext : Python, IPython, IControlFlowContext
3232
protected Stack<IControlFlowContext> _context_stack;
3333
protected IControlFlowContext _outer_context;
3434

35+
protected Dictionary<string, ITensorOrOperation> _external_values;
36+
3537
public ControlFlowContext()
3638
{
3739
_context_stack = new Stack<IControlFlowContext>();
@@ -40,15 +42,43 @@ public ControlFlowContext()
4042
public string name { get => _name; }
4143
protected string _name;
4244

43-
public void __init__()
45+
public void __init__(ValuesDef values_def = null, string import_scope = null)
4446
{
45-
47+
_outer_context = ops.get_default_graph()._get_control_flow_context();
48+
if (values_def != null)
49+
_init_values_from_proto(values_def, import_scope: import_scope);
4650
}
4751

4852
public void __enter__()
4953
{
5054
}
5155

56+
/// <summary>
57+
/// Initializes values and external_values from `ValuesDef` protocol buffer.
58+
/// </summary>
59+
/// <param name="values_def"></param>
60+
/// <param name="import_scope"></param>
61+
protected void _init_values_from_proto(ValuesDef values_def, string import_scope = null)
62+
{
63+
_external_values = new Dictionary<string, ITensorOrOperation>();
64+
foreach (var value in values_def.Values)
65+
_values.Add(value);
66+
var g = ops.get_default_graph();
67+
foreach(var value in values_def.ExternalValues)
68+
{
69+
var k = ops.prepend_name_scope(value.Key, import_scope);
70+
var v = value.Value;
71+
_external_values[k] = g.as_graph_element(ops.prepend_name_scope(v, import_scope));
72+
}
73+
74+
var op_names = _values.Where(x => !_external_values.ContainsKey(x))
75+
.Select(x => x.Split(':')[0])
76+
.ToArray();
77+
78+
foreach (var op in op_names)
79+
(g.as_graph_element(op) as Operation)._set_control_flow_context(this);
80+
}
81+
5282
public void __exit__()
5383
{
5484
}

src/TensorFlowNET.Core/Operations/Operation.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ public void _update_input(int index, Tensor tensor)
287287
// Reset cached inputs.
288288
_inputs = null;// new InputList(new Tensor[] { tensor }); // is this right? original code: self._inputs_val=None
289289
// TODO: implement below code dependencies
290-
// c_api.TF_UpdateEdge(graph, output, input, status);
290+
c_api.TF_UpdateEdge(graph, output, input, status);
291291
}
292292

293293
private void _assert_same_graph(Tensor tensor)

src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ raise TypeError("pred must not be a Python bool")
330330
tensor.op.graph.prevent_fetching(tensor.op);
331331

332332
// Build the graph for the true branch in a new context.
333-
var context_t = new CondContext(pred, pivot_1, branch: 1);
333+
var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1);
334334
ITensorOrOperation orig_res_t;
335335
Tensor res_t;
336336
try
@@ -343,7 +343,7 @@ raise TypeError("pred must not be a Python bool")
343343
context_t.Exit();
344344
}
345345
// Build the graph for the false branch in a new context.
346-
var context_f = new CondContext(pred, pivot_2, branch: 0);
346+
var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0);
347347
ITensorOrOperation orig_res_f;
348348
Tensor res_f;
349349
try
@@ -411,13 +411,13 @@ public static Tensor[] cond<T>(Tensor pred,
411411
tensor.op.graph.prevent_fetching(tensor.op);
412412

413413
// Build the graph for the true branch in a new context.
414-
var context_t = new CondContext(pred, pivot_1, branch: 1);
414+
var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1);
415415
context_t.Enter();
416416
var (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
417417
context_t.Exit();
418418

419419
// Build the graph for the false branch in a new context.
420-
var context_f = new CondContext(pred, pivot_2, branch: 0);
420+
var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0);
421421
context_f.Enter();
422422
var (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
423423
context_f.Exit();

0 commit comments

Comments
 (0)