Skip to content

Commit a6d8ffa

Browse files
authored
Fix a bug in tpu.py and xla.py that while creating an identity node for control input edges under rewrite context, the parent control flow context is lost. (#23446)
PiperOrigin-RevId: 219724472
1 parent 8ce231a commit a6d8ffa

File tree

2 files changed

+10
-16
lines changed

2 files changed

+10
-16
lines changed

tensorflow/contrib/compiler/xla.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -179,14 +179,11 @@ def AddOp(self, op):
179179
if external_control_inputs:
180180
# Use an identity to pull control inputs as data inputs. Note that we
181181
# ignore ops which don't have outputs. TODO(phawkins): fix that.
182-
with ops.control_dependencies(None):
183-
self.Enter()
184-
external_control_inputs = [
185-
array_ops.identity(x.outputs[0]).op
186-
for x in external_control_inputs
187-
if x.outputs
188-
]
189-
self.Exit()
182+
external_control_inputs = [
183+
array_ops.identity(x.outputs[0]).op
184+
for x in external_control_inputs
185+
if x.outputs
186+
]
190187
# pylint: disable=protected-access
191188
op._add_control_inputs(external_control_inputs)
192189
# pylint: enable=protected-access

tensorflow/contrib/tpu/python/tpu/tpu.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -371,14 +371,11 @@ def AddOp(self, op):
371371
if external_control_inputs:
372372
# Use an identity to pull control inputs as data inputs. Note that we
373373
# ignore ops which don't have outputs. TODO(phawkins): fix that.
374-
with ops.control_dependencies(None):
375-
self.Enter()
376-
external_control_inputs = [
377-
array_ops.identity(x.outputs[0]).op
378-
for x in external_control_inputs
379-
if x.outputs
380-
]
381-
self.Exit()
374+
external_control_inputs = [
375+
array_ops.identity(x.outputs[0]).op
376+
for x in external_control_inputs
377+
if x.outputs
378+
]
382379
# pylint: disable=protected-access
383380
op._add_control_inputs(external_control_inputs)
384381
# pylint: enable=protected-access

0 commit comments

Comments
 (0)