Skip to content

Commit eff582d

Browse files
theweihotensorflower-gardener
authored andcommitted
Skeleton for PartitionedVariable class
Change: 123347356
1 parent 99671f3 commit eff582d

5 files changed

Lines changed: 264 additions & 118 deletions

File tree

tensorflow/python/kernel_tests/partitioned_variables_test.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,6 @@
2222
from six.moves import xrange # pylint: disable=redefined-builtin
2323
import tensorflow as tf
2424

25-
from tensorflow.python.ops import variable_scope
26-
27-
# pylint: disable=protected-access
28-
get_partitioned_variable_list = variable_scope._get_partitioned_variable_list
29-
# pylint: enable=protected-access
30-
3125

3226
class PartitionerCreatorsTest(tf.test.TestCase):
3327

@@ -39,8 +33,9 @@ def _testVariableAxisSizePartitioner(self, name, axis, max_shard_bytes,
3933
axis=axis, max_shard_bytes=max_shard_bytes, max_shards=max_shards)
4034

4135
with tf.variable_scope("root", partitioner=partitioner):
42-
v0_list, v0_part = get_partitioned_variable_list(
43-
name, dtype=tf.float32, shape=(4, 8, 16, 32))
36+
v0 = tf.get_variable(name, dtype=tf.float32, shape=(4, 8, 16, 32))
37+
v0_list = v0._get_variable_list()
38+
v0_part = v0._get_partitions()
4439
self.assertEqual(len(v0_list), expected_axis_shards)
4540
self.assertAllEqual(v0_part, expected_partitions)
4641

@@ -118,10 +113,13 @@ def testVariableAxisSizePartitioner(self):
118113
axis=3, max_shard_bytes=32768, bytes_per_string_element=8)
119114

120115
with tf.variable_scope("root", partitioner=partitioner_axis3_str):
121-
v3str_list, v3str_part = get_partitioned_variable_list(
116+
v3str = tf.get_variable(
122117
"v3str",
123-
initializer=np.array([""] * 4*8*16*32).reshape(4, 8, 16, 32),
124-
dtype=tf.string, shape=(4, 8, 16, 32))
118+
initializer=np.array([""] * 4 * 8 * 16 * 32).reshape(4, 8, 16, 32),
119+
dtype=tf.string,
120+
shape=(4, 8, 16, 32))
121+
v3str_list = v3str._get_variable_list()
122+
v3str_part = v3str._get_partitions()
125123

126124
# Now the estimated bytes_per_slice = 4*8*16*bytes_per_string_element
127125
# which is equal to 4096. Setting a max_shard_bytes of 32768
@@ -191,9 +189,11 @@ def testName(self):
191189
with self.test_session():
192190
rnd_par = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
193191
with tf.variable_scope("hola") as vs:
194-
vs1 = tf.create_partitioned_variables([2, 4], [1, 2], rnd_par)
192+
vs1 = tf.create_partitioned_variables(
193+
[2, 4], [1, 2], rnd_par, dtype=tf.int32)
195194
with tf.variable_scope(vs, reuse=True):
196-
vs2 = tf.create_partitioned_variables([2, 4], [1, 2], rnd_par)
195+
vs2 = tf.create_partitioned_variables(
196+
[2, 4], [1, 2], rnd_par, dtype=tf.int32)
197197
tf.initialize_all_variables().run()
198198
var1_name = vs1[0]._save_slice_info.full_name
199199
var2_name = vs2[0]._save_slice_info.full_name

tensorflow/python/kernel_tests/variable_scope_test.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -556,27 +556,35 @@ class VariableScopeWithPartitioningTest(tf.test.TestCase):
556556

557557
def testResultNameMatchesRequested(self):
558558
with tf.variable_scope("scope0", partitioner=axis0_into2_partitioner):
559-
v_concat = tf.get_variable("name0", shape=(3, 1, 1))
559+
v = tf.get_variable("name0", shape=(3, 1, 1))
560+
self.assertEqual(v.name, "scope0/name0")
561+
v_concat = v.as_tensor()
560562
self.assertEqual(v_concat.name, "scope0/name0:0")
561563
variables = tf.get_collection(tf.GraphKeys.VARIABLES)
562-
concat_variables = tf.get_collection(tf.GraphKeys.CONCATENATED_VARIABLES)
563-
self.assertTrue(v_concat.name in [x.name for x in concat_variables])
564-
self.assertTrue("scope0/name0_0:0" in [x.name for x in variables])
565-
self.assertTrue("scope0/name0_1:0" in [x.name for x in variables])
566-
self.assertFalse("scope0/name0_2:0" in [x.name for x in variables])
564+
self.assertTrue("scope0/name0/part_0:0" in [x.name for x in variables])
565+
self.assertTrue("scope0/name0/part_1:0" in [x.name for x in variables])
566+
self.assertFalse("scope0/name0/part_2:0" in [x.name for x in variables])
567567

568568
def testBreaksIfPartitioningChanges(self):
569569
with tf.variable_scope("scope0", partitioner=axis0_into2_partitioner):
570570
tf.get_variable("name0", shape=(3, 1, 1))
571571

572-
with tf.variable_scope("scope0", partitioner=axis0_into3_partitioner):
572+
with tf.variable_scope("scope0",
573+
partitioner=axis0_into3_partitioner,
574+
reuse=True):
573575
with self.assertRaisesRegexp(
574-
ValueError, "Partitioner returned a different partitioning"):
576+
ValueError,
577+
"Trying to reuse partitioned variable .* but specified partitions .* "
578+
"and found partitions .*"):
575579
tf.get_variable("name0", shape=(3, 1, 1))
576580

577-
with tf.variable_scope("scope0", partitioner=axis0_into1_partitioner):
581+
with tf.variable_scope("scope0",
582+
partitioner=axis0_into1_partitioner,
583+
reuse=True):
578584
with self.assertRaisesRegexp(
579-
ValueError, "Partitioner returned a different partitioning"):
585+
ValueError,
586+
"Trying to reuse partitioned variable .* but specified partitions .* "
587+
"and found partitions .*"):
580588
tf.get_variable("name0", shape=(3, 1, 1))
581589

582590
def testReturnsExistingConcatenatedValueIfReuse(self):
@@ -586,6 +594,14 @@ def testReturnsExistingConcatenatedValueIfReuse(self):
586594
v_concat_2 = tf.get_variable("name0", shape=(3, 1, 1))
587595
self.assertEqual(v_concat, v_concat_2)
588596

597+
def testAllowsReuseWithoutPartitioner(self):
598+
with tf.variable_scope("scope0", partitioner=axis0_into2_partitioner):
599+
v = tf.get_variable("name0", shape=(3, 1, 1))
600+
with tf.variable_scope("scope0", reuse=True):
601+
v_reused = tf.get_variable("name0")
602+
603+
self.assertEqual(v, v_reused)
604+
589605
def testPartitionConcatenatesAlongCorrectAxis(self):
590606
def _part_axis_0(**unused_kwargs):
591607
return (2, 1, 1)
@@ -600,13 +616,13 @@ def _part_axis_1(**unused_kwargs):
600616
self.assertEqual(v0.get_shape(), (2, 2, 2))
601617
self.assertEqual(v1.get_shape(), (2, 2, 2))
602618

603-
n0_0 = tf.get_default_graph().get_tensor_by_name("root/n0_0:0")
604-
n0_1 = tf.get_default_graph().get_tensor_by_name("root/n0_1:0")
619+
n0_0 = tf.get_default_graph().get_tensor_by_name("root/n0/part_0:0")
620+
n0_1 = tf.get_default_graph().get_tensor_by_name("root/n0/part_1:0")
605621
self.assertEqual(n0_0.get_shape(), (1, 2, 2))
606622
self.assertEqual(n0_1.get_shape(), (1, 2, 2))
607623

608-
n1_0 = tf.get_default_graph().get_tensor_by_name("root/n1_0:0")
609-
n1_1 = tf.get_default_graph().get_tensor_by_name("root/n1_1:0")
624+
n1_0 = tf.get_default_graph().get_tensor_by_name("root/n1/part_0:0")
625+
n1_1 = tf.get_default_graph().get_tensor_by_name("root/n1/part_1:0")
610626
self.assertEqual(n1_0.get_shape(), (2, 1, 2))
611627
self.assertEqual(n1_1.get_shape(), (2, 1, 2))
612628

tensorflow/python/ops/partitioned_variables.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -210,20 +210,15 @@ def create_partitioned_variables(
210210
partitioner = lambda **unused_kwargs: slicing
211211

212212
with variable_scope.variable_op_scope(
213-
[], name, "PartitionedVariable", reuse=reuse) as scope:
214-
213+
[], name, "PartitionedVariable", reuse=reuse):
215214
# pylint: disable=protected-access
216-
vs, _ = variable_scope._get_partitioned_variable_list(
217-
name="part",
215+
partitioned_var = variable_scope._get_partitioned_variable(
216+
name=None,
218217
shape=shape,
219218
dtype=dtype,
220219
initializer=initializer,
221220
trainable=trainable,
222221
partitioner=partitioner,
223222
collections=collections)
224-
225-
for var in vs:
226-
var._save_slice_info.full_name = scope.name
223+
return partitioned_var._get_variable_list()
227224
# pylint: enable=protected-access
228-
229-
return vs

0 commit comments

Comments
 (0)