@@ -16,18 +16,14 @@ public class CondContext : ControlFlowContext
1616 /// The boolean tensor for the cond predicate
1717 /// </summary>
1818 private Tensor _pred ;
19+
1920 public Tensor pred => _pred ;
2021
2122 /// <summary>
2223 /// 0 or 1 representing this branch
2324 /// </summary>
2425 private int _branch ;
2526
26- /// <summary>
27- ///
28- /// </summary>
29- private List < string > _values = new List < string > ( ) ;
30-
3127 private Dictionary < string , Tensor > _external_values = new Dictionary < string , Tensor > ( ) ;
3228
3329 /// <summary>
@@ -66,72 +62,166 @@ public CondContext(Tensor pred,
6662 }
6763
6864 /// <summary>
69- /// Add the subgraph defined by fn() to the graph .
65+ /// Add `val` to the current context and its outer context recursively .
7066 /// </summary>
71- public ( T , Tensor ) BuildCondBranch < T > ( Func < T > fn )
67+ /// <param name="val"></param>
68+ public override Tensor AddValue ( Tensor val )
7269 {
73- // Add the subgraph defined by fn() to the graph.
74- var pre_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
75- var original_result = fn ( ) ;
76- var post_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
70+ Tensor result = null ;
71+ if ( _values . Contains ( val . name ) )
72+ {
73+ // Use the real value if it comes from outer context. This is needed in
74+ // particular for nested conds.
75+ if ( _external_values . ContainsKey ( val . name ) )
76+ result = _external_values [ val . name ] ;
77+ else
78+ result = val ;
79+ }
80+ else
81+ {
82+ result = val ;
83+ _values . Add ( val . name ) ;
84+ // TODO: _outer_context
85+ if ( _outer_context != null )
86+ {
87+ result = _outer_context . AddValue ( val ) ;
88+ _values . Add ( result . name ) ;
89+ _external_values [ result . name ] = result ;
90+ }
91+ // TODO: how to do 'with' here??
92+ //with(ops.control_dependencies(null), ctrl =>
93+ //{
94+ var ( r0 , r1 ) = control_flow_ops . _SwitchRefOrTensor ( result , _pred ) ;
95+ result = new [ ] { r0 , r1 } [ _branch ] ;
96+ if ( _outer_context != null )
97+ _outer_context . AddInnerOp ( result . op ) ;
98+ //});
7799
78- //TODO: port this chunck of missing code:
79- /*
80- if len(post_summaries) > len(pre_summaries):
81- new_summaries = post_summaries[len(pre_summaries):]
82- summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
83- summary_ref[:] = pre_summaries
84- with ops.control_dependencies(new_summaries):
85- if original_result is None:
86- return no_op(), None
87- else:
88- original_result = nest.map_structure(array_ops.identity,
89- original_result)
90- */
91- if ( original_result == null )
92- return ( original_result , null ) ;
100+ result . op . graph . prevent_fetching ( result . op ) ;
101+ result . op . _set_control_flow_context ( this ) ;
93102
94- switch ( original_result )
95- {
96- case Tensor result :
97- return ( original_result , _BuildCondTensor ( new [ ] { result . op } ) ) ;
98- case Operation [ ] results :
99- return ( original_result , _BuildCondTensor ( results ) ) ;
100- case float [ ] fv :
103+ // Mark Switch output as seen by this context and any outer contexts,
104+ // just like what we do for normal op outputs in _AddOpInternal() below.
105+ IControlFlowContext ctxt = this ;
106+ while ( ctxt != null )
101107 {
102- var result = ops . convert_to_tensor ( fv [ 0 ] ) ;
103- return ( original_result , result ) ;
108+ ctxt . values . Add ( result . name ) ;
109+ ctxt = ctxt . outer_context ;
104110 }
105- default :
106- return ( original_result , null ) ;
111+ _external_values [ val . name ] = result ;
107112 }
108- }
109-
110- public ( T [ ] , Tensor [ ] ) BuildCondBranch < T > ( Func < T [ ] > fn )
113+ return result ;
114+ }
115+
116+ /// <summary>
117+ /// Add the subgraph defined by fn() to the graph.
118+ /// </summary>
119+ public ( T , Tensor ) BuildCondBranch < T > ( Func < T > fn )
120+ {
121+ // Add the subgraph defined by fn() to the graph.
122+ var pre_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
123+ var original_result = fn ( ) ;
124+ var post_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
125+
126+ //TODO: port this chunck of missing code:
127+ /*
128+ if len(post_summaries) > len(pre_summaries):
129+ new_summaries = post_summaries[len(pre_summaries):]
130+ summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION) # pylint: disable=protected-access
131+ summary_ref[:] = pre_summaries
132+ with ops.control_dependencies(new_summaries):
133+ if original_result is None:
134+ return no_op(), None
135+ else:
136+ original_result = nest.map_structure(array_ops.identity,
137+ original_result)
138+ */
139+ if ( original_result == null )
140+ return ( original_result , null ) ;
141+
142+ switch ( original_result )
143+ {
144+ case Tensor result :
145+ return ( original_result , _BuildCondTensor ( result ) ) ;
146+ case Operation op :
147+ return ( original_result , _BuildCondTensor ( op ) ) ;
148+ case float [ ] fv :
149+ {
150+ var result = ops . convert_to_tensor ( fv [ 0 ] ) ;
151+ return ( original_result , _BuildCondTensor ( result ) ) ;
152+ }
153+ default :
154+ return ( original_result , null ) ;
155+ }
156+ }
157+
158+ public ( T [ ] , Tensor [ ] ) BuildCondBranch < T > ( Func < T [ ] > fn )
159+ {
160+ // Add the subgraph defined by fn() to the graph.
161+ var pre_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
162+ var original_result = fn ( ) ;
163+ var post_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
164+
165+ switch ( original_result )
166+ {
167+ case Tensor [ ] results :
168+ return ( original_result , results . Select ( _BuildCondTensor ) . ToArray ( ) ) ;
169+ case Operation [ ] results :
170+ return ( original_result , results . Select ( _BuildCondTensor ) . ToArray ( ) ) ;
171+ case float [ ] fv :
172+ var result = ops . convert_to_tensor ( fv [ 0 ] ) ;
173+ return ( original_result , new Tensor [ ] { result } ) ;
174+ default :
175+ return ( original_result , new Tensor [ 0 ] ) ;
176+ }
177+ }
178+
179+ private Tensor _BuildCondTensor ( ITensorOrOperation v )
180+ {
181+ switch ( v )
182+ {
183+ case Operation op :
184+ // Use pivot as the proxy for this op.
185+ return control_flow_ops . with_dependencies ( new Operation [ ] { op } , _pivot ) ;
186+ case Tensor t :
187+ return _ProcessOutputTensor ( t ) ;
188+ default :
189+ return _ProcessOutputTensor ( ops . convert_to_tensor ( v ) ) ;
190+
191+ }
192+ }
193+
194+ /// <summary>
195+ /// Process an output tensor of a conditional branch.
196+ /// </summary>
197+ private Tensor _ProcessOutputTensor ( Tensor val )
111198 {
112- // Add the subgraph defined by fn() to the graph.
113- var pre_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
114- var original_result = fn ( ) ;
115- var post_summaries = ops . get_collection ( ops . GraphKeys . _SUMMARY_COLLECTION ) ;
116-
117- switch ( original_result )
199+ var real_val = val ;
200+ if ( ! _values . Contains ( val . name ) )
118201 {
119- case Tensor [ ] results :
120- return ( original_result , new Tensor [ ] { _BuildCondTensor ( results . Select ( t=> t . op ) . ToArray ( ) ) } ) ;
121- case Operation [ ] results :
122- return ( original_result , new Tensor [ ] { _BuildCondTensor ( results ) } ) ;
123- case float [ ] fv :
124- var result = ops . convert_to_tensor ( fv [ 0 ] ) ;
125- return ( original_result , new Tensor [ ] { result } ) ;
126- default :
127- return ( original_result , new Tensor [ 0 ] ) ;
202+ // Handle the special case of lambda: x
203+ _values . Add ( val . name ) ;
204+ if ( _outer_context != null )
205+ {
206+ real_val = _outer_context . AddValue ( val ) ;
207+ _values . Add ( real_val . name ) ;
208+ _external_values [ real_val . name ] = real_val ;
209+ }
128210 }
211+ else
212+ {
213+ Tensor external_val = null ;
214+ if ( _external_values . ContainsKey ( val . name ) )
215+ external_val = _external_values [ val . name ] ;
216+ if ( external_val != null )
217+ real_val = external_val ;
218+ }
219+ return real_val ;
129220 }
130-
131- private Tensor _BuildCondTensor ( Operation [ ] v )
221+
222+ public override void AddInnerOp ( Operation resultOp )
132223 {
133- // Use pivot as the proxy for this op.
134- return control_flow_ops . with_dependencies ( v , _pivot ) ;
224+ throw new NotImplementedException ( ) ;
135225 }
136- }
137- }
226+ }
227+ }
0 commit comments