Skip to content

Commit a6037e9

Browse files
mrryVijay Vasudevan
authored andcommitted
Add better shape inference for tf.zeros_like() and tf.ones_like().
Previously, partial shape information was discarded, because our constant evaluation for (e.g.) `tf.shape(tf.placeholder([..., None, ...]))` could not produce a Numpy array for the shape. Since the *_like wrappers have access to the input tensor, we can use `Tensor.set_shape()` to add back the partial information. Fixes tensorflow#744. Change: 111856452
1 parent 60bccf6 commit a6037e9

2 files changed

Lines changed: 16 additions & 2 deletions

File tree

tensorflow/python/kernel_tests/constant_op_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,11 @@ def testZerosLike(self):
321321
self.assertTrue(np.array_equal(z_value, np.array([[0] * 3] * 2)))
322322
self.assertEqual([2, 3], z_var.get_shape())
323323

324+
def testZerosLikePartialShape(self):
325+
d = tf.placeholder(tf.float32, shape=[None, 4, None])
326+
z = tf.zeros_like(d)
327+
self.assertEqual(d.get_shape().as_list(), z.get_shape().as_list())
328+
324329
def testGenZerosLike(self):
325330
for dtype in [tf.float32, tf.float64, tf.int32,
326331
tf.uint8, tf.int16, tf.int8,
@@ -406,6 +411,11 @@ def testOnesLike(self):
406411
self.assertTrue(np.array_equal(z_value, np.array([[1] * 3] * 2)))
407412
self.assertEqual([2, 3], z_var.get_shape())
408413

414+
def testOnesLikePartialShape(self):
415+
d = tf.placeholder(tf.float32, shape=[None, 4, None])
416+
z = tf.zeros_like(d)
417+
self.assertEqual(d.get_shape().as_list(), z.get_shape().as_list())
418+
409419
def testGenOnesLike(self):
410420
for dtype in [tf.float32, tf.float64, tf.int32,
411421
tf.uint8, tf.int16, tf.int8,

tensorflow/python/ops/array_ops.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,9 @@ def zeros_like(tensor, dtype=None, name=None):
563563
zeros_shape = shape(tensor)
564564
if dtype is None:
565565
dtype = tensor.dtype
566-
return zeros(zeros_shape, dtype=dtype, name=name)
566+
ret = zeros(zeros_shape, dtype=dtype, name=name)
567+
ret.set_shape(tensor.get_shape())
568+
return ret
567569

568570

569571
def ones_like(tensor, dtype=None, name=None):
@@ -594,7 +596,9 @@ def ones_like(tensor, dtype=None, name=None):
594596
ones_shape = shape(tensor)
595597
if dtype is None:
596598
dtype = tensor.dtype
597-
return ones(ones_shape, dtype=dtype, name=name)
599+
ret = ones(ones_shape, dtype=dtype, name=name)
600+
ret.set_shape(tensor.get_shape())
601+
return ret
598602

599603

600604
def zeros_initializer(shape, dtype=dtypes.float32):

0 commit comments

Comments
 (0)