@@ -556,27 +556,35 @@ class VariableScopeWithPartitioningTest(tf.test.TestCase):
556556
557557 def testResultNameMatchesRequested (self ):
558558 with tf .variable_scope ("scope0" , partitioner = axis0_into2_partitioner ):
559- v_concat = tf .get_variable ("name0" , shape = (3 , 1 , 1 ))
559+ v = tf .get_variable ("name0" , shape = (3 , 1 , 1 ))
560+ self .assertEqual (v .name , "scope0/name0" )
561+ v_concat = v .as_tensor ()
560562 self .assertEqual (v_concat .name , "scope0/name0:0" )
561563 variables = tf .get_collection (tf .GraphKeys .VARIABLES )
562- concat_variables = tf .get_collection (tf .GraphKeys .CONCATENATED_VARIABLES )
563- self .assertTrue (v_concat .name in [x .name for x in concat_variables ])
564- self .assertTrue ("scope0/name0_0:0" in [x .name for x in variables ])
565- self .assertTrue ("scope0/name0_1:0" in [x .name for x in variables ])
566- self .assertFalse ("scope0/name0_2:0" in [x .name for x in variables ])
564+ self .assertTrue ("scope0/name0/part_0:0" in [x .name for x in variables ])
565+ self .assertTrue ("scope0/name0/part_1:0" in [x .name for x in variables ])
566+ self .assertFalse ("scope0/name0/part_2:0" in [x .name for x in variables ])
567567
568568 def testBreaksIfPartitioningChanges (self ):
569569 with tf .variable_scope ("scope0" , partitioner = axis0_into2_partitioner ):
570570 tf .get_variable ("name0" , shape = (3 , 1 , 1 ))
571571
572- with tf .variable_scope ("scope0" , partitioner = axis0_into3_partitioner ):
572+ with tf .variable_scope ("scope0" ,
573+ partitioner = axis0_into3_partitioner ,
574+ reuse = True ):
573575 with self .assertRaisesRegexp (
574- ValueError , "Partitioner returned a different partitioning" ):
576+ ValueError ,
577+ "Trying to reuse partitioned variable .* but specified partitions .* "
578+ "and found partitions .*" ):
575579 tf .get_variable ("name0" , shape = (3 , 1 , 1 ))
576580
577- with tf .variable_scope ("scope0" , partitioner = axis0_into1_partitioner ):
581+ with tf .variable_scope ("scope0" ,
582+ partitioner = axis0_into1_partitioner ,
583+ reuse = True ):
578584 with self .assertRaisesRegexp (
579- ValueError , "Partitioner returned a different partitioning" ):
585+ ValueError ,
586+ "Trying to reuse partitioned variable .* but specified partitions .* "
587+ "and found partitions .*" ):
580588 tf .get_variable ("name0" , shape = (3 , 1 , 1 ))
581589
582590 def testReturnsExistingConcatenatedValueIfReuse (self ):
@@ -586,6 +594,14 @@ def testReturnsExistingConcatenatedValueIfReuse(self):
586594 v_concat_2 = tf .get_variable ("name0" , shape = (3 , 1 , 1 ))
587595 self .assertEqual (v_concat , v_concat_2 )
588596
597+ def testAllowsReuseWithoutPartitioner (self ):
598+ with tf .variable_scope ("scope0" , partitioner = axis0_into2_partitioner ):
599+ v = tf .get_variable ("name0" , shape = (3 , 1 , 1 ))
600+ with tf .variable_scope ("scope0" , reuse = True ):
601+ v_reused = tf .get_variable ("name0" )
602+
603+ self .assertEqual (v , v_reused )
604+
589605 def testPartitionConcatenatesAlongCorrectAxis (self ):
590606 def _part_axis_0 (** unused_kwargs ):
591607 return (2 , 1 , 1 )
@@ -600,13 +616,13 @@ def _part_axis_1(**unused_kwargs):
600616 self .assertEqual (v0 .get_shape (), (2 , 2 , 2 ))
601617 self .assertEqual (v1 .get_shape (), (2 , 2 , 2 ))
602618
603- n0_0 = tf .get_default_graph ().get_tensor_by_name ("root/n0_0 :0" )
604- n0_1 = tf .get_default_graph ().get_tensor_by_name ("root/n0_1 :0" )
619+ n0_0 = tf .get_default_graph ().get_tensor_by_name ("root/n0/part_0 :0" )
620+ n0_1 = tf .get_default_graph ().get_tensor_by_name ("root/n0/part_1 :0" )
605621 self .assertEqual (n0_0 .get_shape (), (1 , 2 , 2 ))
606622 self .assertEqual (n0_1 .get_shape (), (1 , 2 , 2 ))
607623
608- n1_0 = tf .get_default_graph ().get_tensor_by_name ("root/n1_0 :0" )
609- n1_1 = tf .get_default_graph ().get_tensor_by_name ("root/n1_1 :0" )
624+ n1_0 = tf .get_default_graph ().get_tensor_by_name ("root/n1/part_0 :0" )
625+ n1_1 = tf .get_default_graph ().get_tensor_by_name ("root/n1/part_1 :0" )
610626 self .assertEqual (n1_0 .get_shape (), (2 , 1 , 2 ))
611627 self .assertEqual (n1_1 .get_shape (), (2 , 1 , 2 ))
612628
0 commit comments