81 changes: 81 additions & 0 deletions test/cpp/api/dataloader.cpp
@@ -2224,4 +2224,85 @@ TEST(DataLoaderTest, ChunkDatasetCrossChunkShuffle) {
std::equal(result.begin(), result.end(), expected_result.begin()));
}
}
}

TEST(DataLoaderTest, CustomPreprocessPolicy) {
const size_t chunk_size = 5;
const size_t batch_size = 10;

struct D : public datasets::ChunkDataReader<int> {
public:
using BatchType = datasets::ChunkDataReader<int>::ChunkType;
D(size_t chunk_count) : chunk_count_(chunk_count) {}

BatchType read_chunk(size_t chunk_index) override {
BatchType batch_data(chunk_size);
auto rand_gen = []() { return std::rand() % 100; };
std::generate(batch_data.begin(), batch_data.end(), rand_gen);
return batch_data;
}

size_t chunk_count() override {
return chunk_count_;
    }

    void reset() override {}
size_t chunk_count_;
};

  // Custom preprocessing policy: sort the chunk data in ascending order.
auto sorting_policy = [](std::vector<int>& raw_batch_data) {
std::sort(raw_batch_data.begin(), raw_batch_data.end());
};
std::function<void(std::vector<int>&)> policy_function =
sorting_policy;

const size_t prefetch_count = 1;
const size_t cache_size = 10;
const size_t cross_chunk_shuffle_counts[] = {1, 2};
const size_t chunk_counts[] = {3, 4};

samplers::SequentialSampler chunk_sampler(0);

for (auto chunk_count : chunk_counts) {
for (auto cross_chunk_shuffle_count : cross_chunk_shuffle_counts) {
D data_reader(chunk_count);

datasets::SharedBatchDataset<datasets::ChunkDataset<
D,
samplers::SequentialSampler,
samplers::SequentialSampler>>
dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
D,
samplers::SequentialSampler,
samplers::SequentialSampler>>(
data_reader,
chunk_sampler,
chunk_sampler,
datasets::ChunkDatasetOptions(
prefetch_count,
batch_size,
cache_size,
cross_chunk_shuffle_count),
policy_function);

auto data_loader = torch::data::make_data_loader(
dataset, DataLoaderOptions(batch_size).workers(0));

std::vector<int> result;
for (auto iterator = data_loader->begin(); iterator != data_loader->end();
++iterator) {
auto batch_result = *iterator;
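      // The sorting policy runs once per loaded chunk group (chunk_size *
      // cross_chunk_shuffle_count elements), so a batch spanning several
      // groups is sorted within each chunk-sized segment, not end to end.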
if (batch_result.size() > chunk_size * cross_chunk_shuffle_count) {
        for (size_t i = 0; i < batch_result.size(); i += chunk_size) {
ASSERT_TRUE(std::is_sorted(
batch_result.begin() + i,
batch_result.begin() + i + chunk_size));
}
} else {
ASSERT_TRUE(std::is_sorted(batch_result.begin(), batch_result.end()));
}
}
}
}
}
20 changes: 19 additions & 1 deletion torch/csrc/api/include/torch/data/datasets/chunk.h
@@ -320,11 +320,14 @@ class ChunkDataset final
ChunkReader chunk_reader,
ChunkSampler chunk_sampler,
ExampleSampler example_sampler,
ChunkDatasetOptions options)
ChunkDatasetOptions options,
std::function<void(UnwrappedBatchType&)> preprocessing_policy =
std::function<void(UnwrappedBatchType&)>())
: chunk_reader_(std::move(chunk_reader)),
chunk_sampler_(std::move(chunk_sampler)),
example_sampler_(std::move(example_sampler)),
options_(std::move(options)),
preprocessing_policy_(preprocessing_policy),
quit_worker_(false),
running_preloaders_(0),
load_checkpoint_(false) {}
@@ -436,6 +439,9 @@ class ChunkDataset final
std::move(
chunk_data.begin(), chunk_data.end(), std::back_inserter(data));
}
if (preprocessing_policy_) {
preprocessing_policy_(data);
}
if (!data.empty()) { // skip empty chunks.
batch_buffer_->add_chunk_data(std::move(data));
}
@@ -483,6 +489,18 @@ class ChunkDataset final
/// The options the Dataset was configured with.
const ChunkDatasetOptions options_;

  /// An optional std::function that applies custom preprocessing to chunk
  /// data. This is an advanced parameter for developers who want to
  /// preprocess chunk data before examples are sampled into minibatches.
  /// Unlike the collate function, this policy is applied at the chunk level
  /// rather than the minibatch level: when a chunk of data is loaded
  /// (multiple chunks if cross_chunk_shuffle_count_ is greater than 1), the
  /// policy is applied to the full loaded data. It is useful when developers
  /// want to perform preprocessing (such as bucketing) on the chunk data
  /// before the example sampler draws minibatches from it. By default it is
  /// empty and no action is taken.
std::function<void(UnwrappedBatchType&)> preprocessing_policy_;

  /// Indicates whether the worker thread can be torn down.
std::atomic<bool> quit_worker_;

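
For reference, a minimal usage sketch of the new constructor parameter, mirroring the test above. The reader `D`, the samplers, and the option values are assumed from that test; `bucketing_policy` is an illustrative name, not part of the API:

// A minimal sketch, assuming the ChunkDataReader D and the samplers from
// the test above. The policy runs once per loaded chunk group, before the
// example sampler draws minibatches from it.
D data_reader(/*chunk_count=*/4);
samplers::SequentialSampler chunk_sampler(0);
samplers::SequentialSampler example_sampler(0);

// Chunk-level "bucketing": sort the loaded chunk so that each minibatch
// draws neighboring (similar) values.
std::function<void(std::vector<int>&)> bucketing_policy =
    [](std::vector<int>& chunk_data) {
      std::sort(chunk_data.begin(), chunk_data.end());
    };

auto dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
    D, samplers::SequentialSampler, samplers::SequentialSampler>>(
    data_reader,
    chunk_sampler,
    example_sampler,
    datasets::ChunkDatasetOptions(
        /*preloader_count=*/1, /*batch_size=*/10, /*cache_size=*/10),
    bucketing_policy);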