216 changes: 214 additions & 2 deletions test/cpp/api/dataloader.cpp
@@ -8,6 +8,7 @@
#include <test/cpp/api/support.h>

#include <c10/util/ArrayRef.h>
#include <c10/util/tempfile.h>

#include <algorithm>
#include <chrono>
@@ -98,7 +99,9 @@ TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
samplers::SequentialSampler sampler(0);

auto initialization_function =
[&](size_t preloader_count, size_t batch_size, size_t cache_size) {
[&](size_t preloader_count,
size_t batch_size,
size_t cache_size) {
datasets::SharedBatchDataset<datasets::ChunkDataset<
DummyChunkDataReader,
samplers::SequentialSampler,
@@ -111,7 +114,9 @@ TEST(DataTest, ChunkDataSetWithInvalidInitParameter) {
sampler,
sampler,
datasets::ChunkDatasetOptions(
preloader_count, batch_size, cache_size));
preloader_count,
batch_size,
cache_size));
};

ASSERT_THROWS_WITH(
@@ -1465,6 +1470,8 @@ TEST(DataLoaderTest, StatefulDatasetWithNoWorkers) {
void reset() override {
counter = 0;
}
void save(torch::serialize::OutputArchive& archive) const override{};
void load(torch::serialize::InputArchive& archive) override {}
int counter = 0;
};

@@ -1501,6 +1508,8 @@ TEST(DataLoaderTest, StatefulDatasetWithManyWorkers) {
void reset() override {
counter = 0;
}
void save(torch::serialize::OutputArchive& archive) const override{};
void load(torch::serialize::InputArchive& archive) override {}
int counter = 0;
std::mutex mutex;
};
@@ -1538,6 +1547,8 @@ TEST(DataLoaderTest, StatefulDatasetWithMap) {
void reset() override {
counter = 0;
}
void save(torch::serialize::OutputArchive& archive) const override{};
void load(torch::serialize::InputArchive& archive) override {}
int counter = 0;
};

@@ -1585,6 +1596,8 @@ TEST(DataLoaderTest, StatefulDatasetWithCollate) {
void reset() override {
counter = 0;
}
void save(torch::serialize::OutputArchive& archive) const override{};
void load(torch::serialize::InputArchive& archive) override {}
int counter = 0;
};

@@ -1880,4 +1893,203 @@ TEST(DataLoaderTest, ChunkDatasetDoesNotHang) {
// to fill the batch buffer but it is not draining. Still we need to exit
// cleanly.
auto iterator = data_loader->begin();
}

// Test ChunkDataset save function.
// Note [save/load ChunkDataset as ChunkSampler]:
// The chunk sampler inside ChunkDataset runs in a separate thread pool, not on
// the main thread, so it is very hard to know its exact status at the moment
// ChunkDataset::save/ChunkDataset::load is called. Purely for testing, we rely
// on the implementation fact that the serialization format of the sampler is
// the same as that of ChunkDataset, and manually drive the chunk sampler via
// its own save/load methods to validate values. This is only for testing the
// specific save/load functionality; in real use, the matching
// ChunkDataset::save and ChunkDataset::load methods should be used.
TEST(DataLoaderTest, ChunkDatasetSave) {
const size_t chunk_count_ = 6;
const size_t chunk_size = 10;

struct DummyTestChunkDataReader : datasets::ChunkDataReader<int> {
public:
using BatchType = datasets::ChunkDataReader<int>::ChunkType;

BatchType read_chunk(size_t chunk_index) override {
return batch_data_;
}

size_t chunk_count() override {
return chunk_count_;
};

void reset() override{};
BatchType batch_data_ = BatchType(chunk_size, 0);
};

const size_t prefetch_count = 1;
const size_t batch_size = chunk_size;
const size_t dataloader_worker_count = 0;
samplers::SequentialSampler sampler(0);
const int epoch_count = 2;

DummyTestChunkDataReader data_reader;

// tested save_intervals
const size_t save_intervals[] = {1, 2};

using datasets::ChunkDatasetOptions;

for (auto save_interval : save_intervals) {
auto tempfile = c10::make_tempfile();

datasets::SharedBatchDataset<datasets::ChunkDataset<
DummyTestChunkDataReader,
samplers::SequentialSampler,
samplers::SequentialSampler>>
dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
DummyTestChunkDataReader,
samplers::SequentialSampler,
samplers::SequentialSampler>>(
data_reader,
sampler,
sampler,
ChunkDatasetOptions(
prefetch_count, batch_size, chunk_size /*cache size*/));

auto data_loader = torch::data::make_data_loader(
dataset,
DataLoaderOptions(batch_size).workers(dataloader_worker_count));

for (int epoch_index = 0; epoch_index < epoch_count; ++epoch_index) {
int iteration_count = 0;
for (auto iterator = data_loader->begin(); iterator != data_loader->end();
++iterator, ++iteration_count) {
if ((iteration_count + 1) % save_interval == 0) {
torch::save(*dataset, tempfile.name);

samplers::SequentialSampler new_sampler(0);

// See Note [save/load ChunkDataset as ChunkSampler]
torch::load(new_sampler, tempfile.name);

// Verify the save logic. For ChunkDataset, the chunk data is stored in a
// cache inside the dataset. One pool of threads is constantly writing to
// the cache, and a different pool of threads is constantly reading from
// it. Because of this asynchrony, at the time of get_batch() it is not
// fully deterministic which chunk has been written to the cache.
// But we can still derive a bounded window on the expected output and
// verify the logic against it. In this test, the cache size is
// configured to equal the chunk size and the batch size, so chunks are
// written to the cache one at a time; only after the current batch is
// retrieved is the next chunk written. Now, in iteration 0, after the
// first batch is retrieved and we save the dataset's state, there are
// three possible scenarios for the writer thread:
// 1. it has not started loading the next chunk yet, so the sequential
// sampler index is still 0;
// 2. it has started loading the second chunk, so the sequential sampler
// index is at 1;
// 3. it has finished loading the second chunk and started loading the
// third; because the cache is still fully occupied by the data from the
// second chunk, it is waiting to write to the cache. At this point the
// sampler index is at 2.
// So we have a window of [0, 2] in which we expect the saved sampler
// index to fall. Note that the sequential sampler advances to the next
// index automatically inside next(), so what gets saved is the next
// index rather than the current one: after handing out the first index,
// the sampler has already moved on to the second, and that is the value
// that is saved. As a result we need to shift the window by one, giving
// the expected window [1, 3].
// The same analysis applies to every iteration, so in the general case
// the saved index should fall into the range
// [iteration + 1, iteration + 3], which is what the assertion below
// validates.
ASSERT_TRUE(
new_sampler.index() >= iteration_count + 1 &&
new_sampler.index() <= iteration_count + 3);
}
}
}
}
}

// Test ChunkDataset load function.
TEST(DataLoaderTest, ChunkDatasetLoad) {
auto tempfile = c10::make_tempfile();

const size_t prefetch_count = 1;
const size_t batch_size = 10;
const size_t dataloader_worker_count = 0;
const size_t save_interval = 2;

DummyChunkDataReader data_reader;
samplers::SequentialSampler sampler(0);

const size_t skipped_chunk = 2;

// Configure sampler to skip 2 chunks
{
sampler.reset(data_reader.chunk_count());
sampler.next(skipped_chunk);

// See Note [save/load ChunkDataset as ChunkSampler]
torch::save(sampler, tempfile.name);
}

// Test functionality across an epoch boundary. The first epoch should be
// affected by the checkpoint, but the second should start normally.
const int epoch_count = 2;

datasets::SharedBatchDataset<datasets::ChunkDataset<
DummyChunkDataReader,
samplers::SequentialSampler,
samplers::SequentialSampler>>
dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
DummyChunkDataReader,
samplers::SequentialSampler,
samplers::SequentialSampler>>(
data_reader,
sampler,
sampler,
datasets::ChunkDatasetOptions(
prefetch_count, batch_size, 20 /*cache size*/));

torch::load(*dataset, tempfile.name);

auto data_loader = torch::data::make_data_loader(
dataset, DataLoaderOptions(batch_size).workers(dataloader_worker_count));

for (int epoch_index = 0; epoch_index < epoch_count; ++epoch_index) {
int iteration_count = 0;

// For the first epoch, batches should come from the third chunk, because
// the checkpoint skipped the first two chunks. For the next epoch,
// iteration should start from the first batch.
int initial_value = epoch_index == 0 ? 15 : 0;

for (auto iterator = data_loader->begin(); iterator != data_loader->end();
++iterator, ++iteration_count) {
DummyChunkDataReader::BatchType batch = *iterator;

std::vector<int> expected_result;
size_t expected_size = (epoch_index > 0 && iteration_count == 3) ? 5 : 10;
expected_result.resize(expected_size);
std::iota(expected_result.begin(), expected_result.end(), initial_value);

ASSERT_EQ(batch.size(), expected_result.size());
ASSERT_TRUE(
std::equal(batch.begin(), batch.end(), expected_result.begin()));

initial_value += batch_size;
}
}

samplers::SequentialSampler new_sampler(0);

// See Note [save/load ChunkDataset as ChunkSampler]
torch::load(new_sampler, tempfile.name);

ASSERT_EQ(new_sampler.index(), skipped_chunk);
}
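
Taken together, the two tests above exercise a save-then-resume workflow. The following is a minimal sketch, under stated assumptions, of how that workflow might look in user code: MyChunkReader, the checkpoint file name, and the checkpoint interval are hypothetical, and only the torch::save/torch::load calls on the dereferenced shared dataset come from the functionality added in this PR.

// Illustrative sketch; MyChunkReader and the file name are placeholders.
#include <torch/torch.h>

#include <cstddef>
#include <vector>

// A stand-in chunk reader producing 100 chunks of 10 ints each.
struct MyChunkReader : torch::data::datasets::ChunkDataReader<int> {
  ChunkType read_chunk(size_t chunk_index) override {
    return ChunkType(10, static_cast<int>(chunk_index));
  }
  size_t chunk_count() override {
    return 100;
  }
  void reset() override {}
};

void train_with_checkpoints(bool resume) {
  using namespace torch::data;

  samplers::SequentialSampler chunk_sampler(0);
  samplers::SequentialSampler example_sampler(0);

  auto dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
      MyChunkReader,
      samplers::SequentialSampler,
      samplers::SequentialSampler>>(
      MyChunkReader{},
      chunk_sampler,
      example_sampler,
      datasets::ChunkDatasetOptions(
          /*preloader_count=*/2, /*batch_size=*/32, /*cache_size=*/64));

  if (resume) {
    // Restore the chunk sampler's saved position; it takes effect on the
    // next reset(), i.e. at the start of the first epoch (see chunk.h below).
    torch::load(*dataset, "chunk_checkpoint.pt");
  }

  auto data_loader =
      make_data_loader(dataset, DataLoaderOptions(32).workers(0));

  size_t step = 0;
  for (auto& batch : *data_loader) {
    (void)batch;  // ... consume the std::vector<int> batch here ...
    if (++step % 100 == 0) {
      // Periodically persist the dataset state (the chunk sampler index).
      torch::save(*dataset, "chunk_checkpoint.pt");
    }
  }
}

Whether a checkpoint taken mid-epoch resumes exactly at the next unread chunk depends on where the preloader threads were at save time, as the window analysis in ChunkDatasetSave above explains.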
32 changes: 25 additions & 7 deletions torch/csrc/api/include/torch/data/datasets/chunk.h
@@ -7,6 +7,8 @@
#include <queue>
#include <thread>

#include <torch/serialize.h>

namespace torch {
namespace data {
namespace datasets {
@@ -270,7 +272,7 @@ struct ChunkDatasetOptions {
/// The size of each batch.
TORCH_ARG(size_t, batch_size);

// the capacity of the queue for batch caching.
/// The capacity of the queue for batch caching.
TORCH_ARG(size_t, cache_size) = 2048;
};

@@ -308,7 +310,8 @@ class ChunkDataset final
example_sampler_(std::move(example_sampler)),
options_(std::move(options)),
quit_worker_(false),
running_preloaders_(0) {}
running_preloaders_(0),
load_checkpoint_(false) {}

virtual ~ChunkDataset() {
// stop batch buffer first.
@@ -332,7 +335,6 @@ class ChunkDataset final
"The requested batch size does not match with the initialized batch size.\n"
" The requested batch size is ", batch_size,
", while the dataset is created with batch size equal to ", options_.batch_size_);

return batch_buffer_->get_batch();
}

@@ -352,9 +354,11 @@ class ChunkDataset final
free_workers();
preload_threads_.clear();

chunk_reader_.reset();

chunk_sampler_.reset(chunk_reader_.chunk_count());
if (!load_checkpoint_){
chunk_reader_.reset();
chunk_sampler_.reset(chunk_reader_.chunk_count());
load_checkpoint_ = false;
}
Contributor:
Uh, why do we need this deferred loading business? Because when we start iterating in a data loader it always calls reset() at the beginning of an epoch?

Contributor Author:
> Because when we start iterating in a data loader it always calls reset() at the beginning of an epoch?

That's exactly why we need to load here -- if the 'load' option is set, chunk_sampler_ needs to be resumed from the file for the first epoch, when reset() is called by the data loader. This way, ChunkDataset reads chunks starting from the last saved point instead of from the beginning. This is the core load logic. After the first epoch, any following epoch should just reset chunk_sampler_ as normal.

Please let me know if you have any questions about this functionality.


// Throw out any existing cached batch in the buffer and re-creates a new
// chunk buffer.
@@ -385,6 +389,17 @@ class ChunkDataset final
return chunk_sampler_;
}

void save(serialize::OutputArchive& archive) const override {
std::lock_guard<std::mutex> lock(chunk_index_guard_);
chunk_sampler_.save(archive);
}

void load(serialize::InputArchive& archive) override{
std::lock_guard<std::mutex> lock(chunk_index_guard_);
chunk_sampler_.load(archive);
load_checkpoint_ = true;
}

private:
/// running on worker thread to preload chunk data.
void preloader(size_t id) {
@@ -455,7 +470,10 @@ class ChunkDataset final
std::atomic<size_t> running_preloaders_;

// mutex to synchronize chunk sampler next() call.
std::mutex chunk_index_guard_;
mutable std::mutex chunk_index_guard_;

// boolean value to indicate whether we need to load the checkpoint for chunk_sampler_.
bool load_checkpoint_;
};
} // namespace datasets
} // namespace data
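
For readers skimming the diff, the change above follows a deferred-load pattern: load() only records the restored sampler state and raises a flag, and that flag is consumed by reset(), which the data loader calls at the start of every epoch. The sketch below isolates that pattern; the class and member names are illustrative, not the actual ChunkDataset, and it clears the flag unconditionally so that only the first epoch after a load is affected, which is the behavior described in the author's reply above.

// Simplified illustration of the deferred-load pattern (not the real class).
#include <torch/torch.h>

#include <cstddef>

class CheckpointableChunkSource {
 public:
  void save(torch::serialize::OutputArchive& archive) const {
    // The chunk sampler's position is the only state worth persisting.
    sampler_.save(archive);
  }

  void load(torch::serialize::InputArchive& archive) {
    // Restore the saved position, but do not act on it yet; the data loader
    // has not started the epoch.
    sampler_.load(archive);
    load_checkpoint_ = true;
  }

  // Called by the data loader at the beginning of every epoch.
  void reset(size_t chunk_count) {
    if (!load_checkpoint_) {
      // Normal epoch: start sampling chunks from the beginning.
      sampler_.reset(chunk_count);
    }
    // The restored position is consumed exactly once, by the first epoch
    // after a load; later epochs reset as usual.
    load_checkpoint_ = false;
  }

 private:
  torch::data::samplers::SequentialSampler sampler_{0};
  bool load_checkpoint_ = false;
};
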
25 changes: 25 additions & 0 deletions torch/csrc/api/include/torch/data/datasets/stateful.h
@@ -30,7 +30,32 @@ class StatefulDataset
public:
/// Resets internal state of the dataset.
virtual void reset() = 0;

/// Saves the statefulDataset's state to OutputArchive.
virtual void save(serialize::OutputArchive& archive) const = 0;

/// Deserializes the statefulDataset's state from the `archive`.
virtual void load(serialize::InputArchive& archive) = 0;
};

/// Serializes a statefulDataset to `OutputArchive`.
template <typename... Args>
serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
const StatefulDataset<Args...>& statefulDataset) {
statefulDataset.save(archive);
return archive;
}

/// Deserializes a statefulDataset from an `InputArchive`.
template <typename... Args>
serialize::InputArchive& operator>>(
serialize::InputArchive& archive,
StatefulDataset<Args...>& statefulDataset) {
statefulDataset.load(archive);
return archive;
}

} // namespace datasets
} // namespace data
} // namespace torch
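
The dummy datasets in the tests earlier in this PR satisfy the new pure-virtual methods with empty bodies. As a hedged sketch of what a real implementation could look like, the class below persists its read position by writing a tensor into the archive, the same approach the built-in samplers use; the class name, archive key, and counter semantics are illustrative assumptions, not part of this PR.

// Minimal sketch, not part of the PR: a stateful dataset that serializes its
// read position so iteration can resume after torch::load().
#include <torch/torch.h>

#include <cstddef>
#include <vector>

struct CountingDataset
    : torch::data::datasets::StatefulDataset<CountingDataset,
                                             std::vector<int>,
                                             size_t> {
  explicit CountingDataset(size_t size) : size_(size) {}

  // Returns the next batch, or nullopt once the dataset is exhausted.
  torch::optional<std::vector<int>> get_batch(size_t batch_size) override {
    if (counter_ >= size_) {
      return torch::nullopt;
    }
    std::vector<int> batch;
    for (size_t i = 0; i < batch_size && counter_ < size_; ++i, ++counter_) {
      batch.push_back(static_cast<int>(counter_));
    }
    return batch;
  }

  torch::optional<size_t> size() const override {
    return size_;
  }

  void reset() override {
    counter_ = 0;
  }

  // The new StatefulDataset interface: persist and restore the position.
  void save(torch::serialize::OutputArchive& archive) const override {
    archive.write("counter", torch::tensor(static_cast<int64_t>(counter_)));
  }

  void load(torch::serialize::InputArchive& archive) override {
    torch::Tensor tensor;
    archive.read("counter", tensor);
    counter_ = static_cast<size_t>(tensor.item<int64_t>());
  }

 private:
  size_t size_;
  size_t counter_ = 0;
};

With these overrides in place, torch::save(dataset, "counting.pt") and torch::load(dataset, "counting.pt") (an arbitrary file name) round-trip the counter through the operator<< and operator>> overloads added above, mirroring what the ChunkDataset tests do with temporary files.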