Skip to content

Commit fe954af

Browse files
Illia Polosukhintensorflower-gardener
authored andcommitted
Fixing sparse_ops_test string comparison failure. Fixing embedding_lookup_unique to support >2d params as per description.
Change: 137730923
1 parent c5ccfe7 commit fe954af

3 files changed

Lines changed: 36 additions & 11 deletions

File tree

tensorflow/contrib/layers/python/layers/embedding_ops.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -346,31 +346,34 @@ def embedding_lookup_unique(params, ids, name=None):
346346
"""Version of embedding_lookup that avoids duplicate lookups.
347347
348348
This can save communication in the case of repeated ids.
349-
Same interface as embedding_lookup.
349+
Same interface as embedding_lookup. Except it supports multi-dimensional `ids`
350+
which allows to not reshape input/output to fit gather.
350351
351352
Args:
352353
params: A list of tensors with the same shape and type, or a
353-
`PartitionedVariable`.
354+
`PartitionedVariable`. Shape `[index, d1, d2, ...]`.
354355
ids: A one-dimensional `Tensor` with type `int32` or `int64` containing
355-
the ids to be looked up in `params`.
356+
the ids to be looked up in `params`. Shape `[ids1, ids2, ...]`.
356357
name: A name for this operation (optional).
357358
358359
Returns:
359-
A `Tensor` with the same type as the tensors in `params`.
360+
A `Tensor` with the same type as the tensors in `params` and dimension of
361+
`[ids1, ids2, d1, d2, ...]`.
360362
361363
Raises:
362364
ValueError: If `params` is empty.
363365
"""
364366
with ops.name_scope(name, "EmbeddingLookupUnique", [params, ids]):
365-
params = ops.convert_to_tensor(params)
366367
ids = ops.convert_to_tensor(ids)
367368
shape = array_ops.shape(ids)
368369
ids_flat = array_ops.reshape(
369370
ids, math_ops.reduce_prod(shape, keep_dims=True))
370371
unique_ids, idx = array_ops.unique(ids_flat)
371372
unique_embeddings = embedding_ops.embedding_lookup(params, unique_ids)
372373
embeds_flat = array_ops.gather(unique_embeddings, idx)
373-
embed_shape = array_ops.concat(0, [shape, [-1]])
374+
embed_shape = array_ops.concat(
375+
0, [shape, array_ops.shape(unique_embeddings)[1:]])
374376
embeds = array_ops.reshape(embeds_flat, embed_shape)
375-
embeds.set_shape(ids.get_shape().concatenate(params.get_shape()[1:]))
377+
embeds.set_shape(ids.get_shape().concatenate(
378+
unique_embeddings.get_shape()[1:]))
376379
return embeds

tensorflow/contrib/layers/python/layers/embedding_ops_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,28 @@ def test_embedding_lookup_unique(self):
371371
self.assertEqual(embedded_np.shape, embedded_tf.shape)
372372
np.testing.assert_almost_equal(embedded_np, embedded_tf)
373373

374+
def test_embedding_lookup_unique_param3d(self):
375+
embeds = np.random.randn(5, 3, 3)
376+
idx = np.random.randint(0, 5, 10)
377+
idx2d = np.random.randint(0, 5, (10, 2))
378+
379+
with self.test_session():
380+
embedded_np = embeds[idx]
381+
embedded_np2d = embeds[idx2d]
382+
embedded_tf = tf.contrib.layers.embedding_lookup_unique(
383+
embeds, idx).eval()
384+
embedded_tf_lst = tf.contrib.layers.embedding_lookup_unique(
385+
[embeds], idx).eval()
386+
embedded_tf2d = tf.contrib.layers.embedding_lookup_unique(
387+
embeds, idx2d).eval()
388+
389+
self.assertEqual(embedded_np.shape, embedded_tf.shape)
390+
np.testing.assert_almost_equal(embedded_np, embedded_tf)
391+
self.assertEqual(embedded_np.shape, embedded_tf_lst.shape)
392+
np.testing.assert_almost_equal(embedded_np, embedded_tf_lst)
393+
self.assertEqual(embedded_np2d.shape, embedded_tf2d.shape)
394+
np.testing.assert_almost_equal(embedded_np2d, embedded_tf2d)
395+
374396

375397
if __name__ == "__main__":
376398
tf.test.main()

tensorflow/contrib/layers/python/ops/sparse_ops_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,25 +61,25 @@ def test_dense_to_sparse_tensor_1d_bool(self):
6161

6262
def test_dense_to_sparse_tensor_1d_str(self):
6363
with self.test_session() as sess:
64-
st = sparse_ops.dense_to_sparse_tensor(['qwe', '', 'ewq', ''])
64+
st = sparse_ops.dense_to_sparse_tensor([b'qwe', b'', b'ewq', b''])
6565
result = sess.run(st)
6666
self.assertEqual(result.indices.dtype, np.int64)
6767
self.assertEqual(result.values.dtype, np.object)
6868
self.assertEqual(result.shape.dtype, np.int64)
6969
self.assertAllEqual([[0], [2]], result.indices)
70-
self.assertAllEqual(['qwe', 'ewq'], result.values)
70+
self.assertAllEqual([b'qwe', b'ewq'], result.values)
7171
self.assertAllEqual([4], result.shape)
7272

7373
def test_dense_to_sparse_tensor_1d_str_special_ignore(self):
7474
with self.test_session() as sess:
7575
st = sparse_ops.dense_to_sparse_tensor(
76-
['qwe', '', 'ewq', ''], ignore_value='qwe')
76+
[b'qwe', b'', b'ewq', b''], ignore_value=b'qwe')
7777
result = sess.run(st)
7878
self.assertEqual(result.indices.dtype, np.int64)
7979
self.assertEqual(result.values.dtype, np.object)
8080
self.assertEqual(result.shape.dtype, np.int64)
8181
self.assertAllEqual([[1], [2], [3]], result.indices)
82-
self.assertAllEqual(['', 'ewq', ''], result.values)
82+
self.assertAllEqual([b'', b'ewq', b''], result.values)
8383
self.assertAllEqual([4], result.shape)
8484

8585
def test_dense_to_sparse_tensor_2d(self):

0 commit comments

Comments
 (0)