@@ -154,7 +154,7 @@ class HashTable : public InitializableLookupTable {
154154// table.Find(in_t, &out_t, default_t)
155155//
156156template <class K , class V >
157- class MutableHashTableOfScalars : public LookupInterface {
157+ class MutableHashTableOfScalars final : public LookupInterface {
158158 public:
159159 MutableHashTableOfScalars (OpKernelContext* ctx, OpKernel* kernel) {}
160160
@@ -179,11 +179,14 @@ class MutableHashTableOfScalars : public LookupInterface {
179179 return Status::OK ();
180180 }
181181
182- Status Insert ( const Tensor& keys, const Tensor& values) override {
182+ Status DoInsert ( bool clear, const Tensor& keys, const Tensor& values) {
183183 const auto key_values = keys.flat <K>();
184184 const auto value_values = values.flat <V>();
185185
186186 mutex_lock l (mu_);
187+ if (clear) {
188+ table_.clear ();
189+ }
187190 for (int64 i = 0 ; i < key_values.size (); ++i) {
188191 const K key = SubtleMustCopyUnlessStringOrFloat (key_values (i));
189192 const V value = SubtleMustCopyUnlessStringOrFloat (value_values (i));
@@ -192,6 +195,14 @@ class MutableHashTableOfScalars : public LookupInterface {
192195 return Status::OK ();
193196 }
194197
198+ Status Insert (const Tensor& keys, const Tensor& values) override {
199+ return DoInsert (false , keys, values);
200+ }
201+
202+ Status ImportValues (const Tensor& keys, const Tensor& values) override {
203+ return DoInsert (true , keys, values);
204+ }
205+
195206 Status ExportValues (OpKernelContext* ctx) override {
196207 mutex_lock l (mu_);
197208 int64 size = table_.size ();
@@ -228,7 +239,7 @@ class MutableHashTableOfScalars : public LookupInterface {
228239// Lookup table that wraps an unordered_map. Behaves identical to
229240// MutableHashTableOfScalars except that each value must be a vector.
230241template <class K , class V >
231- class MutableHashTableOfTensors : public LookupInterface {
242+ class MutableHashTableOfTensors final : public LookupInterface {
232243 public:
233244 MutableHashTableOfTensors (OpKernelContext* ctx, OpKernel* kernel) {
234245 OP_REQUIRES_OK (ctx,
@@ -269,12 +280,15 @@ class MutableHashTableOfTensors : public LookupInterface {
269280 return Status::OK ();
270281 }
271282
272- Status Insert ( const Tensor& keys, const Tensor& values) override {
283+ Status DoInsert ( bool clear, const Tensor& keys, const Tensor& values) {
273284 const auto key_values = keys.flat <K>();
274285 const auto value_values = values.flat_inner_dims <V, 2 >();
275286 int64 value_dim = value_shape_.dim_size (0 );
276287
277288 mutex_lock l (mu_);
289+ if (clear) {
290+ table_.clear ();
291+ }
278292 for (int64 i = 0 ; i < key_values.size (); ++i) {
279293 const K key = SubtleMustCopyUnlessStringOrFloat (key_values (i));
280294 ValueArray value_vec;
@@ -287,6 +301,14 @@ class MutableHashTableOfTensors : public LookupInterface {
287301 return Status::OK ();
288302 }
289303
304+ Status Insert (const Tensor& keys, const Tensor& values) override {
305+ return DoInsert (false , keys, values);
306+ }
307+
308+ Status ImportValues (const Tensor& keys, const Tensor& values) override {
309+ return DoInsert (true , keys, values);
310+ }
311+
290312 Status ExportValues (OpKernelContext* ctx) override {
291313 mutex_lock l (mu_);
292314 int64 size = table_.size ();
@@ -420,6 +442,30 @@ class LookupTableExportOp : public OpKernel {
420442REGISTER_KERNEL_BUILDER (Name(" LookupTableExport" ).Device(DEVICE_CPU ),
421443 LookupTableExportOp);
422444
445+ // Clear the table and insert data.
446+ class LookupTableImportOp : public OpKernel {
447+ public:
448+ explicit LookupTableImportOp (OpKernelConstruction* ctx) : OpKernel(ctx) {}
449+
450+ void Compute (OpKernelContext* ctx) override {
451+ lookup::LookupInterface* table;
452+ OP_REQUIRES_OK (ctx, GetLookupTable (" table_handle" , ctx, &table));
453+ core::ScopedUnref unref_me (table);
454+
455+ DataTypeVector expected_inputs = {DT_STRING_REF , table->key_dtype (),
456+ table->value_dtype ()};
457+ OP_REQUIRES_OK (ctx, ctx->MatchSignature (expected_inputs, {}));
458+
459+ const Tensor& keys = ctx->input (1 );
460+ const Tensor& values = ctx->input (2 );
461+ OP_REQUIRES_OK (ctx, table->CheckKeyAndValueTensors (keys, values));
462+ OP_REQUIRES_OK (ctx, table->ImportValues (keys, values));
463+ }
464+ };
465+
466+ REGISTER_KERNEL_BUILDER (Name(" LookupTableImport" ).Device(DEVICE_CPU ),
467+ LookupTableImportOp);
468+
423469// Register the HashTable op with the currently supported key and value types.
424470#define REGISTER_KERNEL (key_dtype, value_dtype ) \
425471 REGISTER_KERNEL_BUILDER ( \
0 commit comments