Skip to content

Commit f65c8ce

Browse files
authored
Merge pull request #6139 from gunan/cp
[CMake] Fix support for custom kernels in `tf.contrib.metrics`.
2 parents b8ee14f + 9ff4b1a commit f65c8ce

File tree

5 files changed

+10
-10
lines changed

5 files changed

+10
-10
lines changed

tensorflow/contrib/cmake/tf_core_ops.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contr
4646
GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc")
4747
GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc")
4848
GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc")
49+
GENERATE_CONTRIB_OP_LIBRARY(metrics_set "${tensorflow_source_dir}/tensorflow/contrib/metrics/ops/set_ops.cc")
4950
GENERATE_CONTRIB_OP_LIBRARY(word2vec "${tensorflow_source_dir}/tensorflow/models/embedding/word2vec_ops.cc")
5051

5152
########################################################

tensorflow/contrib/cmake/tf_python.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,8 @@ GENERATE_PYTHON_OP_LIB("contrib_factorization_factorization_ops"
504504
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_factorization_ops.py)
505505
GENERATE_PYTHON_OP_LIB("contrib_framework_variable_ops"
506506
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/framework/python/ops/gen_variable_ops.py)
507+
GENERATE_PYTHON_OP_LIB("contrib_metrics_set_ops"
508+
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/metrics/python/ops/gen_set_ops.py)
507509
GENERATE_PYTHON_OP_LIB("contrib_word2vec_ops"
508510
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/models/embedding/gen_word2vec.py
509511
SHAPE_FUNCTIONS_NOT_REQUIRED)

tensorflow/contrib/cmake/tf_tests.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
119119
"${tensorflow_source_dir}/tensorflow/python/training/*_test.py"
120120
"${tensorflow_source_dir}/tensorflow/tensorboard/*_test.py"
121121
"${tensorflow_source_dir}/tensorflow/models/*_test.py"
122+
"${tensorflow_source_dir}/tensorflow/contrib/metrics/*_test.py"
122123
)
123124

124125
# exclude the onces we don't want

tensorflow/contrib/metrics/BUILD

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,7 @@ tf_gen_op_libs(
3535

3636
tf_gen_op_wrapper_py(
3737
name = "set_ops",
38-
hidden = [
39-
"DenseToDenseSetOperation",
40-
"DenseToSparseSetOperation",
41-
"SparseToSparseSetOperation",
42-
"SetSize",
43-
],
38+
out = "python/ops/gen_set_ops.py",
4439
deps = [":set_ops_op_lib"],
4540
)
4641

tensorflow/contrib/metrics/python/ops/set_ops.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from tensorflow.contrib.framework.python.framework import tensor_util
2121

22+
from tensorflow.contrib.metrics.python.ops import gen_set_ops
2223
from tensorflow.contrib.util import loader
2324
from tensorflow.python.framework import dtypes
2425
from tensorflow.python.framework import ops
@@ -56,7 +57,7 @@ def set_size(a, validate_indices=True):
5657
if a.values.dtype.base_dtype not in _VALID_DTYPES:
5758
raise TypeError("Invalid dtype %s." % a.values.dtype)
5859
# pylint: disable=protected-access
59-
return _set_ops.set_size(a.indices, a.values, a.shape, validate_indices)
60+
return gen_set_ops.set_size(a.indices, a.values, a.shape, validate_indices)
6061

6162
ops.NotDifferentiable("SetSize")
6263

@@ -100,17 +101,17 @@ def _set_operation(a, b, set_operation, validate_indices=True):
100101
# pylint: disable=protected-access
101102
if isinstance(a, sparse_tensor.SparseTensor):
102103
if isinstance(b, sparse_tensor.SparseTensor):
103-
indices, values, shape = _set_ops.sparse_to_sparse_set_operation(
104+
indices, values, shape = gen_set_ops.sparse_to_sparse_set_operation(
104105
a.indices, a.values, a.shape, b.indices, b.values, b.shape,
105106
set_operation, validate_indices)
106107
else:
107108
raise ValueError("Sparse,Dense is not supported, but Dense,Sparse is. "
108109
"Please flip the order of your inputs.")
109110
elif isinstance(b, sparse_tensor.SparseTensor):
110-
indices, values, shape = _set_ops.dense_to_sparse_set_operation(
111+
indices, values, shape = gen_set_ops.dense_to_sparse_set_operation(
111112
a, b.indices, b.values, b.shape, set_operation, validate_indices)
112113
else:
113-
indices, values, shape = _set_ops.dense_to_dense_set_operation(
114+
indices, values, shape = gen_set_ops.dense_to_dense_set_operation(
114115
a, b, set_operation, validate_indices)
115116
# pylint: enable=protected-access
116117
return sparse_tensor.SparseTensor(indices, values, shape)

0 commit comments

Comments
 (0)