Skip to content

Commit a7b7670

Browse files
committed
CondContext: implemented missing functionality
1 parent 26e78cd commit a7b7670

File tree

9 files changed

+381
-89
lines changed

9 files changed

+381
-89
lines changed

src/TensorFlowNET.Core/Graphs/Graph.Operation.cs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,27 @@ public ITensorOrOperation _get_operation_by_tf_operation(IntPtr tf_oper)
4141
{
4242
var op_name = Marshal.PtrToStringAnsi(c_api.TF_OperationName(tf_oper));
4343
return _get_operation_by_name_unsafe(op_name);
44-
}
45-
44+
}
45+
46+
/// <summary>
47+
/// Creates an `Operation` in this graph from the supplied TF_Operation.
48+
///
49+
/// This method is like create_op() except the new Operation is constructed
50+
/// using `c_op`. The returned Operation will have `c_op` as its _c_op
51+
/// field.This is used to create Operation objects around TF_Operations created
52+
/// indirectly by the C API(e.g.by TF_ImportGraphDef, TF_FinishWhile).
53+
///
54+
/// This function does not call Operation._control_flow_post_processing or
55+
/// Graph._control_dependencies_for_inputs (since the inputs may not be
56+
/// available yet). The caller is responsible for calling these methods.
57+
/// </summary>
58+
/// <param name="c_op">a wrapped TF_Operation</param>
59+
/// <param name="compute_device">(Optional.) If True, device functions will be executed
60+
/// to compute the device property of the Operation.</param>
61+
/// <returns>An `Operation` object.</returns>
4662
public Operation _create_op_from_tf_operation(IntPtr c_op, bool compute_device = true)
4763
{
48-
var ret = new Operation(c_op);
64+
var ret = new Operation(c_op, this);
4965
_add_op(ret);
5066

5167
var name_key = ret.name.ToLower();

src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs

Lines changed: 151 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,14 @@ public class CondContext : ControlFlowContext
1616
/// The boolean tensor for the cond predicate
1717
/// </summary>
1818
private Tensor _pred;
19+
1920
public Tensor pred => _pred;
2021

2122
/// <summary>
2223
/// 0 or 1 representing this branch
2324
/// </summary>
2425
private int _branch;
2526

26-
/// <summary>
27-
///
28-
/// </summary>
29-
private List<string> _values = new List<string>();
30-
3127
private Dictionary<string, Tensor> _external_values = new Dictionary<string, Tensor>();
3228

3329
/// <summary>
@@ -66,72 +62,166 @@ public CondContext(Tensor pred,
6662
}
6763

6864
/// <summary>
69-
/// Add the subgraph defined by fn() to the graph.
65+
/// Add `val` to the current context and its outer context recursively.
7066
/// </summary>
71-
public (T, Tensor) BuildCondBranch<T>(Func<T> fn)
67+
/// <param name="val"></param>
68+
public override Tensor AddValue(Tensor val)
7269
{
73-
// Add the subgraph defined by fn() to the graph.
74-
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
75-
var original_result = fn();
76-
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
70+
Tensor result = null;
71+
if (_values.Contains(val.name))
72+
{
73+
// Use the real value if it comes from outer context. This is needed in
74+
// particular for nested conds.
75+
if (_external_values.ContainsKey(val.name))
76+
result = _external_values[val.name];
77+
else
78+
result = val;
79+
}
80+
else
81+
{
82+
result = val;
83+
_values.Add(val.name);
84+
// TODO: _outer_context
85+
if (_outer_context != null)
86+
{
87+
result = _outer_context.AddValue(val);
88+
_values.Add(result.name);
89+
_external_values[result.name] = result;
90+
}
91+
// TODO: how to do 'with' here??
92+
//with(ops.control_dependencies(null), ctrl =>
93+
//{
94+
var (r0, r1) = control_flow_ops._SwitchRefOrTensor(result, _pred);
95+
result = new[]{r0, r1}[_branch];
96+
if (_outer_context != null)
97+
_outer_context.AddInnerOp(result.op);
98+
//});
7799

78-
//TODO: port this chunck of missing code:
79-
/*
80-
if len(post_summaries) > len(pre_summaries):
81-
new_summaries = post_summaries[len(pre_summaries):]
82-
summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
83-
summary_ref[:] = pre_summaries
84-
with ops.control_dependencies(new_summaries):
85-
if original_result is None:
86-
return no_op(), None
87-
else:
88-
original_result = nest.map_structure(array_ops.identity,
89-
original_result)
90-
*/
91-
if (original_result == null)
92-
return (original_result, null);
100+
result.op.graph.prevent_fetching(result.op);
101+
result.op._set_control_flow_context(this);
93102

94-
switch (original_result)
95-
{
96-
case Tensor result:
97-
return (original_result, _BuildCondTensor(new[] { result.op }));
98-
case Operation[] results:
99-
return (original_result, _BuildCondTensor(results));
100-
case float[] fv:
103+
// Mark Switch output as seen by this context and any outer contexts,
104+
// just like what we do for normal op outputs in _AddOpInternal() below.
105+
IControlFlowContext ctxt = this;
106+
while (ctxt != null)
101107
{
102-
var result = ops.convert_to_tensor(fv[0]);
103-
return (original_result, result );
108+
ctxt.values.Add(result.name);
109+
ctxt = ctxt.outer_context;
104110
}
105-
default:
106-
return (original_result, null);
111+
_external_values[val.name] = result;
107112
}
108-
}
109-
110-
public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn)
113+
return result;
114+
}
115+
116+
/// <summary>
117+
/// Add the subgraph defined by fn() to the graph.
118+
/// </summary>
119+
public (T, Tensor) BuildCondBranch<T>(Func<T> fn)
120+
{
121+
// Add the subgraph defined by fn() to the graph.
122+
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
123+
var original_result = fn();
124+
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
125+
126+
//TODO: port this chunck of missing code:
127+
/*
128+
if len(post_summaries) > len(pre_summaries):
129+
new_summaries = post_summaries[len(pre_summaries):]
130+
summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
131+
summary_ref[:] = pre_summaries
132+
with ops.control_dependencies(new_summaries):
133+
if original_result is None:
134+
return no_op(), None
135+
else:
136+
original_result = nest.map_structure(array_ops.identity,
137+
original_result)
138+
*/
139+
if (original_result == null)
140+
return (original_result, null);
141+
142+
switch (original_result)
143+
{
144+
case Tensor result:
145+
return (original_result, _BuildCondTensor(result));
146+
case Operation op:
147+
return (original_result, _BuildCondTensor(op));
148+
case float[] fv:
149+
{
150+
var result = ops.convert_to_tensor(fv[0]);
151+
return (original_result, _BuildCondTensor(result));
152+
}
153+
default:
154+
return (original_result, null);
155+
}
156+
}
157+
158+
public (T[], Tensor[]) BuildCondBranch<T>(Func<T[]> fn)
159+
{
160+
// Add the subgraph defined by fn() to the graph.
161+
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
162+
var original_result = fn();
163+
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
164+
165+
switch (original_result)
166+
{
167+
case Tensor[] results:
168+
return (original_result, results.Select(_BuildCondTensor).ToArray());
169+
case Operation[] results:
170+
return (original_result, results.Select(_BuildCondTensor).ToArray());
171+
case float[] fv:
172+
var result = ops.convert_to_tensor(fv[0]);
173+
return (original_result, new Tensor[] { result });
174+
default:
175+
return (original_result, new Tensor[0]);
176+
}
177+
}
178+
179+
private Tensor _BuildCondTensor(ITensorOrOperation v)
180+
{
181+
switch (v)
182+
{
183+
case Operation op:
184+
// Use pivot as the proxy for this op.
185+
return control_flow_ops.with_dependencies(new Operation[] { op }, _pivot);
186+
case Tensor t:
187+
return _ProcessOutputTensor(t);
188+
default:
189+
return _ProcessOutputTensor(ops.convert_to_tensor(v));
190+
191+
}
192+
}
193+
194+
/// <summary>
195+
/// Process an output tensor of a conditional branch.
196+
/// </summary>
197+
private Tensor _ProcessOutputTensor(Tensor val)
111198
{
112-
// Add the subgraph defined by fn() to the graph.
113-
var pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
114-
var original_result = fn();
115-
var post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION);
116-
117-
switch (original_result)
199+
var real_val = val;
200+
if (!_values.Contains(val.name))
118201
{
119-
case Tensor[] results:
120-
return (original_result, new Tensor[] { _BuildCondTensor(results.Select(t=>t.op).ToArray())});
121-
case Operation[] results:
122-
return (original_result, new Tensor[] { _BuildCondTensor (results) });
123-
case float[] fv:
124-
var result = ops.convert_to_tensor(fv[0]);
125-
return (original_result, new Tensor[] { result });
126-
default:
127-
return (original_result, new Tensor[0]);
202+
// Handle the special case of lambda: x
203+
_values.Add(val.name);
204+
if (_outer_context != null)
205+
{
206+
real_val = _outer_context.AddValue(val);
207+
_values.Add(real_val.name);
208+
_external_values[real_val.name] = real_val;
209+
}
128210
}
211+
else
212+
{
213+
Tensor external_val = null;
214+
if (_external_values.ContainsKey(val.name))
215+
external_val = _external_values[val.name];
216+
if (external_val != null)
217+
real_val = external_val;
218+
}
219+
return real_val;
129220
}
130-
131-
private Tensor _BuildCondTensor(Operation[] v)
221+
222+
public override void AddInnerOp(Operation resultOp)
132223
{
133-
// Use pivot as the proxy for this op.
134-
return control_flow_ops.with_dependencies(v, _pivot);
224+
throw new NotImplementedException();
135225
}
136-
}
137-
}
226+
}
227+
}

0 commit comments

Comments
 (0)