Skip to content

Commit a4dfb8d

Browse files
Merge pull request #49124 from tensorflow/mm-cherrypick-tf-data-segfault-fix-to-r2.5
[tf.data][cherrypick] Fix snapshot segfault when using repeat and pre…
2 parents 2107b1d + 16b8139 commit a4dfb8d

File tree

2 files changed

+19
-10
lines changed

2 files changed

+19
-10
lines changed

tensorflow/core/kernels/data/experimental/snapshot_dataset_op.cc

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -201,8 +201,6 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Reader
201201

202202
explicit Reader(const Params& params, int64 start_index);
203203

204-
~Reader() override;
205-
206204
Status Initialize(IteratorContext* ctx) override;
207205

208206
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
@@ -222,7 +220,7 @@ class SnapshotDatasetV2Op::Dataset::Iterator::Reader
222220

223221
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
224222

225-
DatasetBase* input_ TF_GUARDED_BY(mu_);
223+
DatasetBase* input_ TF_GUARDED_BY(mu_) = nullptr;
226224

227225
std::unique_ptr<InstantiatedCapturedFunction> instantiated_reader_func_
228226
TF_GUARDED_BY(mu_);
@@ -468,7 +466,11 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::GetNextInternal(
468466
bool* end_of_sequence) {
469467
mutex_lock l(mu_);
470468
if (iterator_ == nullptr) {
471-
TF_RETURN_IF_ERROR(InitializeIterator(ctx, nullptr));
469+
Status s = InitializeIterator(ctx, nullptr);
470+
if (!s.ok()) {
471+
iterator_.reset();
472+
return s;
473+
}
472474
}
473475
index_++;
474476
return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
@@ -547,8 +549,6 @@ SnapshotDatasetV2Op::Dataset::Iterator::Reader::Reader(const Params& params,
547549
int64 start_index)
548550
: DatasetIterator<Dataset>(params), start_index_(start_index) {}
549551

550-
SnapshotDatasetV2Op::Dataset::Iterator::Reader::~Reader() { input_->Unref(); }
551-
552552
Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
553553
IteratorContext* ctx) {
554554
mutex_lock l(mu_);
@@ -597,10 +597,6 @@ Status SnapshotDatasetV2Op::Dataset::Iterator::Reader::Initialize(
597597
}
598598
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(reader_output[0], &input_));
599599

600-
// We need to take a reference here as we will use the input_ and
601-
// its iterator.
602-
input_->Ref();
603-
604600
return input_->MakeIterator(ctx, this, prefix(), &input_impl_);
605601
}
606602

tensorflow/python/data/experimental/kernel_tests/snapshot_test.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,19 @@ def testReadOptimizableUsingFlatMap(self):
413413
num_runs_per_fingerprint=1,
414414
num_snapshot_shards_per_run=multiprocessing.cpu_count())
415415

416+
@combinations.generate(test_base.default_test_combinations())
417+
def testRepeatAndPrefetch(self):
418+
"""This test reproduces github.com/tensorflow/tensorflow/issues/48903."""
419+
dataset = dataset_ops.Dataset.from_tensor_slices(np.random.rand(16, 32))
420+
dataset = dataset.apply(snapshot.snapshot(self._snapshot_dir))
421+
dataset = dataset.shuffle(buffer_size=16)
422+
dataset = dataset.batch(16)
423+
dataset = dataset.repeat()
424+
dataset = dataset.prefetch(1)
425+
next_element = self.getNext(dataset)
426+
for _ in range(30):
427+
self.evaluate(next_element())
428+
416429

417430
class LegacySnapshotTest(tf_record_test_base.TFRecordTestBase,
418431
parameterized.TestCase):

0 commit comments

Comments
 (0)