88#include < test/cpp/api/support.h>
99
1010#include < c10/util/ArrayRef.h>
11+ #include < c10/util/tempfile.h>
1112
1213#include < algorithm>
1314#include < chrono>
@@ -98,7 +99,9 @@ TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
9899 samplers::SequentialSampler sampler (0 );
99100
100101 auto initialization_function =
101- [&](size_t preloader_count, size_t batch_size, size_t cache_size) {
102+ [&](size_t preloader_count,
103+ size_t batch_size,
104+ size_t cache_size) {
102105 datasets::SharedBatchDataset<datasets::ChunkDataset<
103106 DummyChunkDataReader,
104107 samplers::SequentialSampler,
@@ -111,7 +114,9 @@ TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
111114 sampler,
112115 sampler,
113116 datasets::ChunkDatasetOptions (
114- preloader_count, batch_size, cache_size));
117+ preloader_count,
118+ batch_size,
119+ cache_size));
115120 };
116121
117122 ASSERT_THROWS_WITH (
@@ -1465,6 +1470,8 @@ TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
// Puts the dataset back into its initial state by zeroing `counter`.
14651470 void reset () override {
14661471 counter = 0 ;
14671472 }
// Serialization hook with an empty body: nothing is written to `archive`.
// NOTE(review): the ';' after the empty body is redundant.
1473+ void save (torch::serialize::OutputArchive& archive) const override {};
// Deserialization hook with an empty body: nothing is read from `archive`.
1474+ void load (torch::serialize::InputArchive& archive) override {}
14681475 int counter = 0 ;
14691476 };
14701477
@@ -1501,6 +1508,8 @@ TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
// Puts the dataset back into its initial state by zeroing `counter`.
15011508 void reset () override {
15021509 counter = 0 ;
15031510 }
// Serialization hook with an empty body: nothing is written to `archive`.
// NOTE(review): the ';' after the empty body is redundant.
1511+ void save (torch::serialize::OutputArchive& archive) const override {};
// Deserialization hook with an empty body: nothing is read from `archive`.
1512+ void load (torch::serialize::InputArchive& archive) override {}
15041513 int counter = 0 ;
15051514 std::mutex mutex;
15061515 };
@@ -1538,6 +1547,8 @@ TEST(DataLoaderTest, StatefulDatasetWithMap) {
// Puts the dataset back into its initial state by zeroing `counter`.
15381547 void reset () override {
15391548 counter = 0 ;
15401549 }
// Serialization hook with an empty body: nothing is written to `archive`.
// NOTE(review): the ';' after the empty body is redundant.
1550+ void save (torch::serialize::OutputArchive& archive) const override {};
// Deserialization hook with an empty body: nothing is read from `archive`.
1551+ void load (torch::serialize::InputArchive& archive) override {}
15411552 int counter = 0 ;
15421553 };
15431554
@@ -1585,6 +1596,8 @@ TEST(DataLoaderTest, StatefulDatasetWithCollate) {
// Puts the dataset back into its initial state by zeroing `counter`.
15851596 void reset () override {
15861597 counter = 0 ;
15871598 }
// Serialization hook with an empty body: nothing is written to `archive`.
// NOTE(review): the ';' after the empty body is redundant.
1599+ void save (torch::serialize::OutputArchive& archive) const override {};
// Deserialization hook with an empty body: nothing is read from `archive`.
1600+ void load (torch::serialize::InputArchive& archive) override {}
15881601 int counter = 0 ;
15891602 };
15901603
@@ -1880,4 +1893,203 @@ TEST(DataLoaderTest, ChunkDatasetDoesNotHang) {
18801893 // to fill the batch buffer but it is not draining. Still we need to exit
18811894 // cleanly.
18821895 auto iterator = data_loader->begin ();
1896+ }
1897+
1898+ // Test ChunkDataset save function.
1899+ // Note [save/load ChunkDataset as ChunkSampler]:
1900+ // The chunk sampler inside ChunkDataset is used by a thread pool that is
1901+ // separate from the main thread. Thus it is very hard to accurately estimate its
1902+ // state when ChunkDataset::save/ChunkDataset::load is called. Purely for the sake
1903+ // of testing, we rely on the implementation detail that the file format for
1904+ // sampler serialization is the same as for ChunkDataset serialization, and
1905+ // manually control the chunk sampler by calling the sampler's save/load methods
1906+ // for value validation. This is only for testing the specific save/load
1907+ // functionality. In real use cases, users should still use the matching
1908+ // ChunkDataset::save and ChunkDataset::load methods.
1909+ TEST (DataLoaderTest, ChunkDatasetSave) {
1910+ const size_t chunk_count_ = 6 ;
1911+ const size_t chunk_size = 10 ;
1912+
// Chunk reader used only by this test: it reports chunk_count_ chunks and
// returns the same batch_data_ for every chunk_index (the index is ignored).
1913+ struct DummyTestChunkDataReader : datasets::ChunkDataReader<int > {
1914+ public:
1915+ using BatchType = datasets::ChunkDataReader<int >::ChunkType;
1916+
1917+ BatchType read_chunk (size_t chunk_index) override {
1918+ return batch_data_;
1919+ }
1920+
1921+ size_t chunk_count () override {
1922+ return chunk_count_;
1923+ };
1924+
1925+ void reset () override {};
1926+ BatchType batch_data_ = BatchType(chunk_size, 0 );
1927+ };
1928+
1929+ const size_t prefetch_count = 1 ;
1930+ const size_t batch_size = chunk_size;
1931+ const size_t dataloader_worker_count = 0 ;
1932+ samplers::SequentialSampler sampler (0 );
1933+ const int epoch_count = 2 ;
1934+
1935+ DummyTestChunkDataReader data_reader;
1936+
1937+ // Save intervals under test: a checkpoint is written on every save_interval-th iteration.
1938+ const size_t save_intervals[] = {1 , 2 };
1939+
1940+ using datasets::ChunkDatasetOptions;
1941+
1942+ for (auto save_interval : save_intervals) {
// A fresh temp file, dataset and data loader are created for each save interval.
1943+ auto tempfile = c10::make_tempfile ();
1944+
1945+ datasets::SharedBatchDataset<datasets::ChunkDataset<
1946+ DummyTestChunkDataReader,
1947+ samplers::SequentialSampler,
1948+ samplers::SequentialSampler>>
1949+ dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
1950+ DummyTestChunkDataReader,
1951+ samplers::SequentialSampler,
1952+ samplers::SequentialSampler>>(
1953+ data_reader,
1954+ sampler,
1955+ sampler,
1956+ ChunkDatasetOptions (
1957+ prefetch_count, batch_size, chunk_size /* cache size*/ ));
1958+
1959+ auto data_loader = torch::data::make_data_loader (
1960+ dataset,
1961+ DataLoaderOptions (batch_size).workers (dataloader_worker_count));
1962+
1963+ for (int epoch_index = 0 ; epoch_index < epoch_count; ++epoch_index) {
1964+ int iteration_count = 0 ;
1965+ for (auto iterator = data_loader->begin (); iterator != data_loader->end ();
1966+ ++iterator, ++iteration_count) {
// iteration_count is zero-based, hence the "+ 1" when checking the interval.
1967+ if ((iteration_count + 1 ) % save_interval == 0 ) {
1968+ torch::save (*dataset, tempfile.name );
1969+
1970+ samplers::SequentialSampler new_sampler (0 );
1971+
1972+ // See Note [save/load ChunkDataset as ChunkSampler]
1973+ torch::load (new_sampler, tempfile.name );
1974+
1975+ // Verify save logic. For ChunkDataset, the chunk data is stored in a
1976+ // cache inside the dataset. One pool of threads is constantly
1977+ // writing to the cache, and a different pool of threads is constantly
1978+ // reading from the cache. Due to this asynchrony, at
1979+ // the time of get_batch(), which chunk is written to the cache is not
1980+ // fully deterministic.
1981+ // But we can still calculate a restricted window on the expected
1982+ // output, hence verify the logic. In this test, the cache size is
1983+ // configured to be the same as chunk size and batch size. So the
1984+ // chunk data is written to the cache one by one. Only after the current
1985+ // batch is retrieved is the next chunk written. Now in iteration 0,
1986+ // after the first batch is retrieved, when we save the dataset
1987+ // status, there are three possible scenarios for the writer thread:
1988+ // 1. it hasn't started loading the next chunk data yet, so the
1989+ // sequential sampler index is still 0;
1990+ // 2. it started to load the second chunk, so the sequential sampler
1991+ // index is at 1;
1992+ // 3. it finished loading the second chunk, and started to load the
1993+ // third chunk; because the cache is still fully occupied by the data
1994+ // from the second chunk, it is waiting to write to the cache. At this
1995+ // point, the sampler index is at 2.
1996+ // So now we have a window of [0, 2], which is what we expected the
1997+ // sampler to save the index from. Note that the sequential sampler
1998+ // advances to the next index automatically in the call to next(). So
1999+ // when it saves the index, it saves the next index instead of the
2000+ // current one. In other words, after getting the first index from
2001+ // sequential sampler, it already moves to the second index. So when
2002+ // we save it, it is the second index we save. As a result,
2003+ // we need to advance the window by one. Now we have the expected
2004+ // window of [1, 3].
2005+ // This analysis applies to all scenarios. So extend it to a more
2006+ // general case: the expected saved index should fall into the
2007+ // range of [iteration + 1, iteration + 3], which is the validation
2008+ // below.
2009+ ASSERT_TRUE (
2010+ new_sampler.index () >= iteration_count + 1 &&
2011+ new_sampler.index () <= iteration_count + 3 );
2012+ }
2013+ }
2014+ }
2015+ }
2016+ }
2017+
2018+ // Test ChunkDataset load function.
// A sampler that has already been advanced by `skipped_chunk` is saved to a
// temp file, the file is loaded into the dataset before iteration starts, and
// the batches produced over two epochs are compared with expected sequences.
2019+ TEST (DataLoaderTest, ChunkDatasetLoad) {
2020+ auto tempfile = c10::make_tempfile ();
2021+
2022+ const size_t prefetch_count = 1 ;
2023+ const size_t batch_size = 10 ;
2024+ const size_t dataloader_worker_count = 0 ;
// NOTE(review): save_interval is not referenced anywhere else in this test.
2025+ const size_t save_interval = 2 ;
2026+
2027+ DummyChunkDataReader data_reader;
2028+ samplers::SequentialSampler sampler (0 );
2029+
2030+ const size_t skipped_chunk = 2 ;
2031+
2032+ // Configure sampler to skip 2 chunks
2033+ {
2034+ sampler.reset (data_reader.chunk_count ());
2035+ sampler.next (skipped_chunk);
2036+
2037+ // See Note [save/load ChunkDataset as ChunkSampler]
2038+ torch::save (sampler, tempfile.name );
2039+ }
2040+
2041+ // Test functionality across the epoch boundary. The first epoch should be
2042+ // affected by the checkpoint, but the second should start normally.
2043+ const int epoch_count = 2 ;
2044+
2045+ datasets::SharedBatchDataset<datasets::ChunkDataset<
2046+ DummyChunkDataReader,
2047+ samplers::SequentialSampler,
2048+ samplers::SequentialSampler>>
2049+ dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
2050+ DummyChunkDataReader,
2051+ samplers::SequentialSampler,
2052+ samplers::SequentialSampler>>(
2053+ data_reader,
2054+ sampler,
2055+ sampler,
2056+ datasets::ChunkDatasetOptions (
2057+ prefetch_count, batch_size, 20 /* cache size*/ ));
2058+
// Restore the dataset from the checkpoint that was written through the sampler above.
2059+ torch::load (*dataset, tempfile.name );
2060+
2061+ auto data_loader = torch::data::make_data_loader (
2062+ dataset, DataLoaderOptions (batch_size).workers (dataloader_worker_count));
2063+
2064+ for (int epoch_index = 0 ; epoch_index < epoch_count; ++epoch_index) {
2065+ int iteration_count = 0 ;
2066+
2067+ // For the first epoch, the returned batch should be returned from the
2068+ // third chunk, because the checkpoint skipped the first two chunks. But
2069+ // for the next epoch, it should start from the first chunk.
// NOTE(review): 15 is presumably the number of values held by the first two
// chunks of DummyChunkDataReader -- confirm against its definition.
2070+ int initial_value = epoch_index == 0 ? 15 : 0 ;
2071+
2072+ for (auto iterator = data_loader->begin (); iterator != data_loader->end ();
2073+ ++iterator, ++iteration_count) {
2074+ DummyChunkDataReader::BatchType batch = *iterator;
2075+
// Every batch is expected to hold 10 consecutive values, except the fourth
// batch of a non-first epoch, which is expected to hold 5.
// NOTE(review): presumably the tail of the data -- confirm against
// DummyChunkDataReader.
2076+ std::vector<int > expected_result;
2077+ size_t expected_size = (epoch_index > 0 && iteration_count == 3 ) ? 5 : 10 ;
2078+ expected_result.resize (expected_size);
2079+ std::iota (expected_result.begin (), expected_result.end (), initial_value);
2080+
2081+ ASSERT_EQ (batch.size (), expected_result.size ());
2082+ ASSERT_TRUE (
2083+ std::equal (batch.begin (), batch.end (), expected_result.begin ()));
2084+
// The next batch is expected to continue right after this one.
2085+ initial_value += batch_size;
2086+ }
2087+ }
2088+
2089+ samplers::SequentialSampler new_sampler (0 );
2090+
2091+ // See Note [save/load ChunkDataset as ChunkSampler]
2092+ torch::load (new_sampler, tempfile.name );
2093+
// This test only writes to tempfile once (the sampler save above), so the
// index read back is expected to equal skipped_chunk.
2094+ ASSERT_EQ (new_sampler.index (), skipped_chunk);
18832095}
0 commit comments