@@ -247,7 +247,7 @@ def fn():
247247 dtype = tf_dtype , tensor_array_name = "foo" , size = 3 )
248248 lengths = constant_op .constant ([1 , 1 , 1 ])
249249 w0 = ta .split (
250- convert ([[1.0 , 101.0 ], [2.0 , 201 .0 ], [3.0 , 301 .0 ]]),
250+ convert ([[1.0 , 101.0 ], [2.0 , 121 .0 ], [3.0 , 127 .0 ]]),
251251 lengths = lengths )
252252 r0 = w0 .read (0 )
253253 r1 = w0 .read (1 )
@@ -256,14 +256,13 @@ def fn():
256256
257257 d0 , d1 , d2 = self .evaluate (xla .compile (fn ))
258258 self .assertAllEqual (convert ([[1.0 , 101.0 ]]), d0 )
259- self .assertAllEqual (convert ([[2.0 , 201 .0 ]]), d1 )
260- self .assertAllEqual (convert ([[3.0 , 301 .0 ]]), d2 )
259+ self .assertAllEqual (convert ([[2.0 , 121 .0 ]]), d1 )
260+ self .assertAllEqual (convert ([[3.0 , 127 .0 ]]), d2 )
261261
262- # Disable temporarily due to b/195023333
263- # @test_util.disable_control_flow_v2("b/122315872 (split)")
264- # def testTensorArraySplitRead(self):
265- # for dtype in self.numeric_tf_types:
266- # self._testTensorArraySplitRead(dtype)
262+ @test_util .disable_control_flow_v2 ("b/122315872 (split)" )
263+ def testTensorArraySplitRead (self ):
264+ for dtype in self .numeric_tf_types :
265+ self ._testTensorArraySplitRead (dtype )
267266
268267 @test_util .disable_control_flow_v2 ("TensorArray.grad is not supported in v2" )
269268 def testTensorGradArrayWriteRead (self ):
0 commit comments