Skip to content

Commit 0bfd671

Browse files
Change Saver to support checkpointing ops other than Variables.
Operation save and restore is encapsulated in a SaveableObject class, which ops can implement to allow checkpointing. Add this support to MutableHashTable. Add an op for hash tables that atomically clears the table and imports the data. Add a test for reshaping to saver_test.py. Change: 129900685
1 parent 266c455 commit 0bfd671

13 files changed

Lines changed: 612 additions & 169 deletions

File tree

tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,18 +67,18 @@ def __init__(self,
6767
value_dtype,
6868
default_value,
6969
num_shards=1,
70-
name=None):
70+
name='ShardedMutableHashTable'):
7171
with ops.name_scope(name, 'sharded_mutable_hash_table') as scope:
7272
super(_ShardedMutableHashTable, self).__init__(key_dtype, value_dtype,
7373
scope)
7474
table_shards = []
75-
for _ in range(num_shards):
75+
for i in range(num_shards):
7676
# TODO(andreasst): add placement hints once bug 30002625 is fixed.
7777
table_shards.append(lookup_ops.MutableHashTable(
7878
key_dtype=key_dtype,
7979
value_dtype=value_dtype,
8080
default_value=default_value,
81-
name=name))
81+
name='%s-%d' % (name, i)))
8282
self._table_shards = table_shards
8383
# TODO(andreasst): add a value_shape() method to LookupInterface
8484
# pylint: disable=protected-access

tensorflow/contrib/lookup/lookup_ops.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tensorflow.python.ops import array_ops
2525
from tensorflow.python.ops import gen_data_flow_ops
2626
from tensorflow.python.ops import math_ops
27+
from tensorflow.python.training.saver import BaseSaverBuilder
2728

2829

2930
class LookupInterface(object):
@@ -710,7 +711,8 @@ def __init__(self,
710711
value_dtype,
711712
default_value,
712713
shared_name=None,
713-
name=None):
714+
name="MutableHashTable",
715+
checkpoint=True):
714716
"""Creates an empty `MutableHashTable` object.
715717
716718
Creates a table, the type of its keys and values are specified by key_dtype
@@ -723,9 +725,14 @@ def __init__(self,
723725
shared_name: If non-empty, this table will be shared under
724726
the given name across multiple sessions.
725727
name: A name for the operation (optional).
728+
checkpoint: if True, the contents of the table are saved to and restored
729+
from checkpoints.
726730
727731
Returns:
728732
A `MutableHashTable` object.
733+
734+
Raises:
735+
ValueError: If checkpoint is True and no name was specified.
729736
"""
730737
self._default_value = ops.convert_to_tensor(default_value,
731738
dtype=value_dtype)
@@ -746,11 +753,14 @@ def __init__(self,
746753
value_shape=self._default_value.get_shape(),
747754
name=name)
748755
# pylint: enable=protected-access
749-
750756
super(MutableHashTable, self).__init__(key_dtype, value_dtype,
751757
self._table_ref.op.name.split(
752758
"/")[-1])
753759

760+
if checkpoint:
761+
saveable = MutableHashTable.MutableHashTableSaveable(self, name)
762+
ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
763+
754764
def size(self, name=None):
755765
"""Compute the number of elements in this table.
756766
@@ -849,3 +859,21 @@ def export(self, name=None):
849859
exported_values.set_shape(exported_keys.get_shape().concatenate(
850860
self._value_shape))
851861
return exported_keys, exported_values
862+
863+
class MutableHashTableSaveable(BaseSaverBuilder.SaveableObject):
864+
"""SaveableObject implementation for MutableHashTable."""
865+
866+
def __init__(self, table, name):
867+
tensors = table.export()
868+
specs = [
869+
BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
870+
BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
871+
]
872+
super(MutableHashTable.MutableHashTableSaveable, self).__init__(table,
873+
specs,
874+
name)
875+
876+
def restore(self, restored_tensors, unused_restored_shapes):
877+
# pylint: disable=protected-access
878+
return gen_data_flow_ops._lookup_table_import(
879+
self.op._table_ref, restored_tensors[0], restored_tensors[1])

tensorflow/contrib/lookup/lookup_ops_test.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import os
2121
import numpy as np
22+
import six
2223
import tensorflow as tf
2324

2425

@@ -267,6 +268,59 @@ def testMutableHashTable(self):
267268
self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys)
268269
self.assertAllEqual([0, 1, 2], sorted_values)
269270

271+
def testSaveRestore(self):
272+
save_path = os.path.join(self.get_temp_dir(), "hash")
273+
274+
with self.test_session(graph=tf.Graph()) as sess:
275+
v0 = tf.Variable(10.0, name="v0")
276+
v1 = tf.Variable(20.0, name="v1")
277+
278+
default_val = -1
279+
keys = tf.constant(["b", "c", "d"], tf.string)
280+
values = tf.constant([0, 1, 2], tf.int64)
281+
table = tf.contrib.lookup.MutableHashTable(
282+
tf.string, tf.int64, default_val, name="t1", checkpoint=True)
283+
284+
save = tf.train.Saver()
285+
tf.initialize_all_variables().run()
286+
287+
# Check that the parameter nodes have been initialized.
288+
self.assertEqual(10.0, v0.eval())
289+
self.assertEqual(20.0, v1.eval())
290+
291+
self.assertAllEqual(0, table.size().eval())
292+
table.insert(keys, values).run()
293+
self.assertAllEqual(3, table.size().eval())
294+
295+
val = save.save(sess, save_path)
296+
self.assertTrue(isinstance(val, six.string_types))
297+
self.assertEqual(save_path, val)
298+
299+
with self.test_session(graph=tf.Graph()) as sess:
300+
v0 = tf.Variable(-1.0, name="v0")
301+
v1 = tf.Variable(-1.0, name="v1")
302+
default_val = -1
303+
table = tf.contrib.lookup.MutableHashTable(
304+
tf.string, tf.int64, default_val, name="t1", checkpoint=True)
305+
table.insert(
306+
tf.constant(["a", "c"], tf.string),
307+
tf.constant([12, 24], tf.int64)).run()
308+
self.assertAllEqual(2, table.size().eval())
309+
310+
save = tf.train.Saver()
311+
312+
# Restore the saved values in the parameter nodes.
313+
save.restore(sess, save_path)
314+
# Check that the parameter nodes have been restored.
315+
self.assertEqual(10.0, v0.eval())
316+
self.assertEqual(20.0, v1.eval())
317+
318+
self.assertAllEqual(3, table.size().eval())
319+
320+
input_string = tf.constant(["a", "b", "c", "d", "e"], tf.string)
321+
output = table.lookup(input_string)
322+
self.assertAllEqual([-1, 0, 1, 2, -1], output.eval())
323+
270324
def testMutableHashTableOfTensors(self):
271325
with self.test_session():
272326
default_val = tf.constant([-1, -1], tf.int64)

tensorflow/core/framework/lookup_interface.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ class LookupInterface : public ResourceBase {
6868

6969
virtual Status ExportValues(OpKernelContext* context) = 0;
7070

71+
virtual Status ImportValues(const Tensor& keys, const Tensor& values) = 0;
72+
7173
// Returns the data type of the key.
7274
virtual DataType key_dtype() const = 0;
7375

tensorflow/core/kernels/initializable_lookup_table.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ class InitializableLookupTable : public LookupInterface {
5656
"implementations");
5757
}
5858

59+
Status ImportValues(const Tensor& keys, const Tensor& values) final {
60+
return errors::Unimplemented(
61+
"ImportValues not supported by InitializableLookupTable "
62+
"implementations");
63+
}
64+
5965
TensorShape value_shape() const final { return TensorShape(); }
6066

6167
// Returns whether the table was initialized and is ready to serve lookups.

tensorflow/core/kernels/lookup_table_op.cc

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ class HashTable : public InitializableLookupTable {
154154
// table.Find(in_t, &out_t, default_t)
155155
//
156156
template <class K, class V>
157-
class MutableHashTableOfScalars : public LookupInterface {
157+
class MutableHashTableOfScalars final : public LookupInterface {
158158
public:
159159
MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {}
160160

@@ -179,11 +179,14 @@ class MutableHashTableOfScalars : public LookupInterface {
179179
return Status::OK();
180180
}
181181

182-
Status Insert(const Tensor& keys, const Tensor& values) override {
182+
Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) {
183183
const auto key_values = keys.flat<K>();
184184
const auto value_values = values.flat<V>();
185185

186186
mutex_lock l(mu_);
187+
if (clear) {
188+
table_.clear();
189+
}
187190
for (int64 i = 0; i < key_values.size(); ++i) {
188191
const K key = SubtleMustCopyUnlessStringOrFloat(key_values(i));
189192
const V value = SubtleMustCopyUnlessStringOrFloat(value_values(i));
@@ -192,6 +195,14 @@ class MutableHashTableOfScalars : public LookupInterface {
192195
return Status::OK();
193196
}
194197

198+
Status Insert(const Tensor& keys, const Tensor& values) override {
199+
return DoInsert(false, keys, values);
200+
}
201+
202+
Status ImportValues(const Tensor& keys, const Tensor& values) override {
203+
return DoInsert(true, keys, values);
204+
}
205+
195206
Status ExportValues(OpKernelContext* ctx) override {
196207
mutex_lock l(mu_);
197208
int64 size = table_.size();
@@ -228,7 +239,7 @@ class MutableHashTableOfScalars : public LookupInterface {
228239
// Lookup table that wraps an unordered_map. Behaves identical to
229240
// MutableHashTableOfScalars except that each value must be a vector.
230241
template <class K, class V>
231-
class MutableHashTableOfTensors : public LookupInterface {
242+
class MutableHashTableOfTensors final : public LookupInterface {
232243
public:
233244
MutableHashTableOfTensors(OpKernelContext* ctx, OpKernel* kernel) {
234245
OP_REQUIRES_OK(ctx,
@@ -269,12 +280,15 @@ class MutableHashTableOfTensors : public LookupInterface {
269280
return Status::OK();
270281
}
271282

272-
Status Insert(const Tensor& keys, const Tensor& values) override {
283+
Status DoInsert(bool clear, const Tensor& keys, const Tensor& values) {
273284
const auto key_values = keys.flat<K>();
274285
const auto value_values = values.flat_inner_dims<V, 2>();
275286
int64 value_dim = value_shape_.dim_size(0);
276287

277288
mutex_lock l(mu_);
289+
if (clear) {
290+
table_.clear();
291+
}
278292
for (int64 i = 0; i < key_values.size(); ++i) {
279293
const K key = SubtleMustCopyUnlessStringOrFloat(key_values(i));
280294
ValueArray value_vec;
@@ -287,6 +301,14 @@ class MutableHashTableOfTensors : public LookupInterface {
287301
return Status::OK();
288302
}
289303

304+
Status Insert(const Tensor& keys, const Tensor& values) override {
305+
return DoInsert(false, keys, values);
306+
}
307+
308+
Status ImportValues(const Tensor& keys, const Tensor& values) override {
309+
return DoInsert(true, keys, values);
310+
}
311+
290312
Status ExportValues(OpKernelContext* ctx) override {
291313
mutex_lock l(mu_);
292314
int64 size = table_.size();
@@ -420,6 +442,30 @@ class LookupTableExportOp : public OpKernel {
420442
REGISTER_KERNEL_BUILDER(Name("LookupTableExport").Device(DEVICE_CPU),
421443
LookupTableExportOp);
422444

445+
// Clear the table and insert data.
446+
class LookupTableImportOp : public OpKernel {
447+
public:
448+
explicit LookupTableImportOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
449+
450+
void Compute(OpKernelContext* ctx) override {
451+
lookup::LookupInterface* table;
452+
OP_REQUIRES_OK(ctx, GetLookupTable("table_handle", ctx, &table));
453+
core::ScopedUnref unref_me(table);
454+
455+
DataTypeVector expected_inputs = {DT_STRING_REF, table->key_dtype(),
456+
table->value_dtype()};
457+
OP_REQUIRES_OK(ctx, ctx->MatchSignature(expected_inputs, {}));
458+
459+
const Tensor& keys = ctx->input(1);
460+
const Tensor& values = ctx->input(2);
461+
OP_REQUIRES_OK(ctx, table->CheckKeyAndValueTensors(keys, values));
462+
OP_REQUIRES_OK(ctx, table->ImportValues(keys, values));
463+
}
464+
};
465+
466+
REGISTER_KERNEL_BUILDER(Name("LookupTableImport").Device(DEVICE_CPU),
467+
LookupTableImportOp);
468+
423469
// Register the HashTable op with the currently supported key and value types.
424470
#define REGISTER_KERNEL(key_dtype, value_dtype) \
425471
REGISTER_KERNEL_BUILDER( \

tensorflow/core/ops/data_flow_ops.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,29 @@ keys: Vector of all keys present in the table.
11011101
values: Tensor of all values in the table. Indexed in parallel with `keys`.
11021102
)doc");
11031103

1104+
REGISTER_OP("LookupTableImport")
1105+
.Input("table_handle: Ref(string)")
1106+
.Input("keys: Tin")
1107+
.Input("values: Tout")
1108+
.Attr("Tin: type")
1109+
.Attr("Tout: type")
1110+
.SetShapeFn([](InferenceContext* c) {
1111+
const Shape* unused;
1112+
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
1113+
TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->input(2), &unused));
1114+
return Status::OK();
1115+
})
1116+
.Doc(R"doc(
1117+
Replaces the contents of the table with the specified keys and values.
1118+
1119+
The tensor `keys` must be of the same type as the keys of the table.
1120+
The tensor `values` must be of the type of the table values.
1121+
1122+
table_handle: Handle to the table.
1123+
keys: Any shape. Keys to look up.
1124+
values: Same shape as `keys`. Values to associate with keys.
1125+
)doc");
1126+
11041127
REGISTER_OP("HashTable")
11051128
.Output("table_handle: Ref(string)")
11061129
.Attr("container: string = ''")

tensorflow/python/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,7 @@ tf_gen_op_wrapper_py(
579579
"InitializeTableFromTextFile",
580580
"LookupTableExport",
581581
"LookupTableFind",
582+
"LookupTableImport",
582583
"LookupTableInsert",
583584
"LookupTableSize",
584585
"MutableHashTable",

tensorflow/python/framework/ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3886,6 +3886,8 @@ class GraphKeys(object):
38863886
UPDATE_OPS = "update_ops"
38873887
# Key to collect losses
38883888
LOSSES = "losses"
3889+
# Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.
3890+
SAVEABLE_OBJECTS = "saveable_objects"
38893891

38903892
# Key to indicate various ops.
38913893
INIT_OP = "init_op"

tensorflow/python/ops/data_flow_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1130,6 +1130,7 @@ def _LookupTableFindShape(op):
11301130

11311131

11321132
@ops.RegisterShape("LookupTableInsert")
1133+
@ops.RegisterShape("LookupTableImport")
11331134
def _LookupTableInsertShape(op):
11341135
"""Shape function for data_flow_ops._lookup_table_insert."""
11351136
op.inputs[0].get_shape().merge_with(tensor_shape.scalar())

0 commit comments

Comments
 (0)