
Commit f39b662

xzhu1900 authored and facebook-github-bot committed
ChunkDataset checkpoint support (#21889)
Summary: When dealing with a large-scale dataset, it is handy to be able to save the dataset status and resume it later, especially when an unexpected crash happens: users don't need to start over from the beginning of the dataset, but can instead reload it from the last checkpoint. This change adds checkpoint save/load support to ChunkDataset. On ChunkDataset construction, the user can specify a file name from which to load a checkpoint. If it is empty, the dataset starts fresh by default; otherwise the ChunkDataset 'fast forwards' the chunk sampler to the corresponding checkpoint. The user can also call ChunkDataset::save() to serialize the current status to a file, which can be used later.

Pull Request resolved: #21889
Differential Revision: D16024582
Pulled By: ailzhang
fbshipit-source-id: 1862ab5116f94c9d29da174ce04a91041d06cad5
1 parent 30d890c commit f39b662
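In practice the workflow exercised by the tests in this commit looks roughly like the sketch below. It is illustrative only: MyChunkDataReader, the checkpoint file name, the sizes, and the save interval are made-up placeholders, it assumes `using namespace torch::data;` as in the test file, and it drives the checkpoint through the generic torch::save/torch::load helpers the way the tests do.

// Sketch: periodically checkpoint a ChunkDataset and resume from the file.
MyChunkDataReader reader; // hypothetical user-defined ChunkDataReader
auto dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
    MyChunkDataReader,
    samplers::SequentialSampler,
    samplers::SequentialSampler>>(
    reader,
    samplers::SequentialSampler(0), // chunk sampler
    samplers::SequentialSampler(0), // example sampler
    datasets::ChunkDatasetOptions(
        /*preloader_count=*/2, /*batch_size=*/64, /*cache_size=*/640));

const bool resume_from_checkpoint = false; // set when a checkpoint file exists
if (resume_from_checkpoint) {
  // Fast-forwards the internal chunk sampler to the saved position.
  torch::load(*dataset, "chunk_dataset_checkpoint.pt");
}

auto data_loader = torch::data::make_data_loader(
    dataset, DataLoaderOptions(64).workers(0));

const size_t save_interval = 100;
size_t iteration = 0;
for (auto& batch : *data_loader) {
  // ... train on `batch` ...
  if (++iteration % save_interval == 0) {
    // Serializes the current chunk-sampler state to the checkpoint file.
    torch::save(*dataset, "chunk_dataset_checkpoint.pt");
  }
}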

File tree

3 files changed: +264 −9 lines changed

test/cpp/api/dataloader.cpp
torch/csrc/api/include/torch/data/datasets/chunk.h
torch/csrc/api/include/torch/data/datasets/stateful.h

test/cpp/api/dataloader.cpp

Lines changed: 214 additions & 2 deletions
@@ -8,6 +8,7 @@
 #include <test/cpp/api/support.h>

 #include <c10/util/ArrayRef.h>
+#include <c10/util/tempfile.h>

 #include <algorithm>
 #include <chrono>
@@ -98,7 +99,9 @@ TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
   samplers::SequentialSampler sampler(0);

   auto initialization_function =
-      [&](size_t preloader_count, size_t batch_size, size_t cache_size) {
+      [&](size_t preloader_count,
+          size_t batch_size,
+          size_t cache_size) {
         datasets::SharedBatchDataset<datasets::ChunkDataset<
             DummyChunkDataReader,
             samplers::SequentialSampler,
@@ -111,7 +114,9 @@ TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
             sampler,
             sampler,
             datasets::ChunkDatasetOptions(
-                preloader_count, batch_size, cache_size));
+                preloader_count,
+                batch_size,
+                cache_size));
       };

   ASSERT_THROWS_WITH(
@@ -1465,6 +1470,8 @@ TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
     void reset() override {
       counter = 0;
     }
+    void save(torch::serialize::OutputArchive& archive) const override{};
+    void load(torch::serialize::InputArchive& archive) override {}
     int counter = 0;
   };

@@ -1501,6 +1508,8 @@ TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
     void reset() override {
       counter = 0;
     }
+    void save(torch::serialize::OutputArchive& archive) const override{};
+    void load(torch::serialize::InputArchive& archive) override {}
     int counter = 0;
     std::mutex mutex;
   };
@@ -1538,6 +1547,8 @@ TEST(DataLoaderTest, StatefulDatasetWithMap) {
     void reset() override {
       counter = 0;
     }
+    void save(torch::serialize::OutputArchive& archive) const override{};
+    void load(torch::serialize::InputArchive& archive) override {}
     int counter = 0;
   };

@@ -1585,6 +1596,8 @@ TEST(DataLoaderTest, StatefulDatasetWithCollate) {
     void reset() override {
       counter = 0;
     }
+    void save(torch::serialize::OutputArchive& archive) const override{};
+    void load(torch::serialize::InputArchive& archive) override {}
     int counter = 0;
   };

@@ -1880,4 +1893,203 @@ TEST(DataLoaderTest, ChunkDatasetDoesNotHang) {
   // to fill the batch buffer but it is not draining. Still we need to exit
   // cleanly.
   auto iterator = data_loader->begin();
+}
+
+// Test ChunkDataset save function.
+// Note [save/load ChunkDataset as ChunkSampler]:
+// The chunk sampler inside ChunkDataset runs in a separate thread pool rather
+// than the main thread, so it is very hard to know its exact status when
+// ChunkDataset::save/ChunkDataset::load is called. Purely for testing, we rely
+// on the implementation fact that the file format for sampler serialization is
+// the same as for ChunkDataset serialization, and manually control the chunk
+// sampler by calling the sampler's save/load methods for value validation.
+// This is only for testing the specific save/load functionality; in a real use
+// case, the user should use the matching ChunkDataset::save and
+// ChunkDataset::load methods.
+TEST(DataLoaderTest, ChunkDatasetSave) {
+  const size_t chunk_count_ = 6;
+  const size_t chunk_size = 10;
+
+  struct DummyTestChunkDataReader : datasets::ChunkDataReader<int> {
+   public:
+    using BatchType = datasets::ChunkDataReader<int>::ChunkType;
+
+    BatchType read_chunk(size_t chunk_index) override {
+      return batch_data_;
+    }
+
+    size_t chunk_count() override {
+      return chunk_count_;
+    };
+
+    void reset() override{};
+    BatchType batch_data_ = BatchType(chunk_size, 0);
+  };
+
+  const size_t prefetch_count = 1;
+  const size_t batch_size = chunk_size;
+  const size_t dataloader_worker_count = 0;
+  samplers::SequentialSampler sampler(0);
+  const int epoch_count = 2;
+
+  DummyTestChunkDataReader data_reader;
+
+  // tested save_intervals
+  const size_t save_intervals[] = {1, 2};
+
+  using datasets::ChunkDatasetOptions;
+
+  for (auto save_interval : save_intervals) {
+    auto tempfile = c10::make_tempfile();
+
+    datasets::SharedBatchDataset<datasets::ChunkDataset<
+        DummyTestChunkDataReader,
+        samplers::SequentialSampler,
+        samplers::SequentialSampler>>
+        dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
+            DummyTestChunkDataReader,
+            samplers::SequentialSampler,
+            samplers::SequentialSampler>>(
+            data_reader,
+            sampler,
+            sampler,
+            ChunkDatasetOptions(
+                prefetch_count, batch_size, chunk_size /*cache size*/));
+
+    auto data_loader = torch::data::make_data_loader(
+        dataset,
+        DataLoaderOptions(batch_size).workers(dataloader_worker_count));
+
+    for (int epoch_index = 0; epoch_index < epoch_count; ++epoch_index) {
+      int iteration_count = 0;
+      for (auto iterator = data_loader->begin(); iterator != data_loader->end();
+           ++iterator, ++iteration_count) {
+        if ((iteration_count + 1) % save_interval == 0) {
+          torch::save(*dataset, tempfile.name);
+
+          samplers::SequentialSampler new_sampler(0);
+
+          // See Note [save/load ChunkDataset as ChunkSampler]
+          torch::load(new_sampler, tempfile.name);
+
+          // Verify the save logic. For ChunkDataset, the chunk data is stored
+          // in a cache inside the dataset. One pool of threads is constantly
+          // writing to the cache and a different pool of threads is constantly
+          // reading from it. Because of this asynchrony, exactly which chunk
+          // has been written to the cache at the time of get_batch() is not
+          // fully deterministic.
+          // We can still derive a restricted window for the expected output
+          // and verify the logic against it. In this test, the cache size is
+          // configured to equal both the chunk size and the batch size, so
+          // chunks are written to the cache one at a time: only after the
+          // current batch is retrieved is the next chunk written. In iteration
+          // 0, after the first batch is retrieved, when we save the dataset
+          // status there are three possible scenarios for the writer thread:
+          // 1. it hasn't started loading the next chunk yet, so the sequential
+          // sampler index is still 0;
+          // 2. it has started loading the second chunk, so the sequential
+          // sampler index is at 1;
+          // 3. it has finished loading the second chunk and started on the
+          // third chunk, but because the cache is still fully occupied by the
+          // second chunk's data it is waiting to write; at this point the
+          // sampler index is at 2.
+          // So the saved index is expected to lie in the window [0, 2]. Note
+          // that the sequential sampler advances to the next index
+          // automatically inside next(), so what gets saved is the next index
+          // rather than the current one: after handing out the first index it
+          // has already moved to the second. As a result the window shifts by
+          // one, giving [1, 3].
+          // The same analysis applies to every iteration, so in general the
+          // saved index should fall into the range
+          // [iteration + 1, iteration + 3], which is the validation below.
+          ASSERT_TRUE(
+              new_sampler.index() >= iteration_count + 1 &&
+              new_sampler.index() <= iteration_count + 3);
+        }
+      }
+    }
+  }
+}
+
+// Test ChunkDataset load function.
+TEST(DataLoaderTest, ChunkDatasetLoad) {
+  auto tempfile = c10::make_tempfile();
+
+  const size_t prefetch_count = 1;
+  const size_t batch_size = 10;
+  const size_t dataloader_worker_count = 0;
+  const size_t save_interval = 2;
+
+  DummyChunkDataReader data_reader;
+  samplers::SequentialSampler sampler(0);
+
+  const size_t skipped_chunk = 2;
+
+  // Configure sampler to skip 2 chunks
+  {
+    sampler.reset(data_reader.chunk_count());
+    sampler.next(skipped_chunk);
+
+    // See Note [save/load ChunkDataset as ChunkSampler]
+    torch::save(sampler, tempfile.name);
+  }
+
+  // Test functionality across the epoch boundary. The first epoch should be
+  // affected by the checkpoint, but the second should start normally.
+  const int epoch_count = 2;
+
+  datasets::SharedBatchDataset<datasets::ChunkDataset<
+      DummyChunkDataReader,
+      samplers::SequentialSampler,
+      samplers::SequentialSampler>>
+      dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
+          DummyChunkDataReader,
+          samplers::SequentialSampler,
+          samplers::SequentialSampler>>(
+          data_reader,
+          sampler,
+          sampler,
+          datasets::ChunkDatasetOptions(
+              prefetch_count, batch_size, 20 /*cache size*/));

+  torch::load(*dataset, tempfile.name);
+
+  auto data_loader = torch::data::make_data_loader(
+      dataset, DataLoaderOptions(batch_size).workers(dataloader_worker_count));
+
+  for (int epoch_index = 0; epoch_index < epoch_count; ++epoch_index) {
+    int iteration_count = 0;
+
+    // For the first epoch, the returned batches should come from the third
+    // chunk, because the checkpoint skipped the first two chunks. For the next
+    // epoch, iteration should start from the first batch.
+    int initial_value = epoch_index == 0 ? 15 : 0;
+
+    for (auto iterator = data_loader->begin(); iterator != data_loader->end();
+         ++iterator, ++iteration_count) {
+      DummyChunkDataReader::BatchType batch = *iterator;
+
+      std::vector<int> expected_result;
+      size_t expected_size = (epoch_index > 0 && iteration_count == 3) ? 5 : 10;
+      expected_result.resize(expected_size);
+      std::iota(expected_result.begin(), expected_result.end(), initial_value);
+
+      ASSERT_EQ(batch.size(), expected_result.size());
+      ASSERT_TRUE(
+          std::equal(batch.begin(), batch.end(), expected_result.begin()));
+
+      initial_value += batch_size;
+    }
+  }
+
+  samplers::SequentialSampler new_sampler(0);
+
+  // See Note [save/load ChunkDataset as ChunkSampler]
+  torch::load(new_sampler, tempfile.name);
+
+  ASSERT_EQ(new_sampler.index(), skipped_chunk);
 }
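The stateful-dataset fixtures above only need to satisfy the new pure-virtual interface, so their save()/load() bodies are left empty. A dataset that actually wants its progress checkpointed would round-trip its state through the archive, much as SequentialSampler persists its index. Below is a minimal sketch of such overrides, assuming an `int counter` member like the fixtures above; the "counter" key name is an arbitrary illustration, not part of this commit.

// Sketch: persist a single counter through the serialize archives.
void save(torch::serialize::OutputArchive& archive) const override {
  archive.write("counter", torch::tensor(static_cast<int64_t>(counter)));
}

void load(torch::serialize::InputArchive& archive) override {
  torch::Tensor counter_tensor;
  archive.read("counter", counter_tensor);
  counter = static_cast<int>(counter_tensor.item<int64_t>());
}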

torch/csrc/api/include/torch/data/datasets/chunk.h

Lines changed: 25 additions & 7 deletions
@@ -7,6 +7,8 @@
 #include <queue>
 #include <thread>

+#include <torch/serialize.h>
+
 namespace torch {
 namespace data {
 namespace datasets {
@@ -270,7 +272,7 @@ struct ChunkDatasetOptions {
   /// The size of each batch.
   TORCH_ARG(size_t, batch_size);

-  // the capacity of the queue for batch caching.
+  /// The capacity of the queue for batch caching.
   TORCH_ARG(size_t, cache_size) = 2048;
 };

@@ -308,7 +310,8 @@ class ChunkDataset final
         example_sampler_(std::move(example_sampler)),
         options_(std::move(options)),
         quit_worker_(false),
-        running_preloaders_(0) {}
+        running_preloaders_(0),
+        load_checkpoint_(false) {}

   virtual ~ChunkDataset() {
     // stop batch buffer first.
@@ -332,7 +335,6 @@ class ChunkDataset final
         "The requested batch size does not match with the initialized batch size.\n"
         " The requested batch size is ", batch_size,
         ", while the dataset is created with batch size equal to ", options_.batch_size_);
-
     return batch_buffer_->get_batch();
   }

@@ -352,9 +354,11 @@ class ChunkDataset final
     free_workers();
     preload_threads_.clear();

-    chunk_reader_.reset();
-
-    chunk_sampler_.reset(chunk_reader_.chunk_count());
+    if (!load_checkpoint_){
+      chunk_reader_.reset();
+      chunk_sampler_.reset(chunk_reader_.chunk_count());
+      load_checkpoint_ = false;
+    }

     // Throw out any existing cached batch in the buffer and re-creates a new
     // chunk buffer.
@@ -385,6 +389,17 @@ class ChunkDataset final
     return chunk_sampler_;
   }

+  void save(serialize::OutputArchive& archive) const override {
+    std::lock_guard<std::mutex> lock(chunk_index_guard_);
+    chunk_sampler_.save(archive);
+  }
+
+  void load(serialize::InputArchive& archive) override {
+    std::lock_guard<std::mutex> lock(chunk_index_guard_);
+    chunk_sampler_.load(archive);
+    load_checkpoint_ = true;
+  }
+
  private:
   /// running on worker thread to preload chunk data.
   void preloader(size_t id) {
@@ -455,7 +470,10 @@ class ChunkDataset final
   std::atomic<size_t> running_preloaders_;

   // mutex to synchronize chunk sampler next() call.
-  std::mutex chunk_index_guard_;
+  mutable std::mutex chunk_index_guard_;
+
+  // boolean value to indicate whether we need to load the checkpoint for chunk_sampler_.
+  bool load_checkpoint_;
 };
 } // namespace datasets
 } // namespace data

torch/csrc/api/include/torch/data/datasets/stateful.h

Lines changed: 25 additions & 0 deletions
@@ -30,7 +30,32 @@ class StatefulDataset
  public:
   /// Resets internal state of the dataset.
   virtual void reset() = 0;
+
+  /// Saves the statefulDataset's state to OutputArchive.
+  virtual void save(serialize::OutputArchive& archive) const = 0;
+
+  /// Deserializes the statefulDataset's state from the `archive`.
+  virtual void load(serialize::InputArchive& archive) = 0;
 };
+
+/// Serializes a statefulDataset to `OutputArchive`.
+template <typename... Args>
+serialize::OutputArchive& operator<<(
+    serialize::OutputArchive& archive,
+    const StatefulDataset<Args...>& statefulDataset) {
+  statefulDataset.save(archive);
+  return archive;
+}
+
+/// Deserializes a statefulDataset from an `InputArchive`.
+template <typename... Args>
+serialize::InputArchive& operator>>(
+    serialize::InputArchive& archive,
+    StatefulDataset<Args...>& statefulDataset) {
+  statefulDataset.load(archive);
+  return archive;
+}
+
 } // namespace datasets
 } // namespace data
 } // namespace torch

0 commit comments
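The operator<< / operator>> overloads added in stateful.h are what let the generic torch::save and torch::load helpers used in the tests accept a StatefulDataset, and therefore a ChunkDataset. At the archive level the round trip looks roughly like the sketch below, assuming the serialize archives' save_to/load_from file helpers and an illustrative file name.

// `dataset` is a SharedBatchDataset wrapping a ChunkDataset, as in the tests.
torch::serialize::OutputArchive output;
output << *dataset;                  // operator<< forwards to ChunkDataset::save
output.save_to("chunk_dataset.pt");

torch::serialize::InputArchive input;
input.load_from("chunk_dataset.pt");
input >> *dataset;                   // operator>> forwards to ChunkDataset::load;
                                     // the next reset() then keeps the restored
                                     // chunk-sampler position instead of rewinding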