Skip to content

Commit 31f1928

Browse files
xzhu1900 authored and facebook-github-bot committed
add sorting policy to ChunkDataset (#23053)
Summary: Add a sorting policy to ChunkDataset. This is considered an advanced parameter for developers who want to apply a 'sorting policy' to the chunk data before it is sampled into minibatches. Unlike the collate method, this policy is applied at the chunk level instead of the minibatch level. When a chunk of data is loaded (multiple chunks if cross_chunk_shuffle_count_ is greater than 1), this policy is applied to the full loaded data. It is useful if developers want to perform some pre-processing (like bucketing) on the chunk data before the example sampler samples the data. Pull Request resolved: #23053 Differential Revision: D16537692 Pulled By: colesbury fbshipit-source-id: cd21ed40ab787a18b8c6dd304e5b806a7a45e6ba
1 parent a356276 commit 31f1928

File tree

2 files changed

+100
-1
lines changed

2 files changed

+100
-1
lines changed

test/cpp/api/dataloader.cpp

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2224,4 +2224,85 @@ TEST(DataLoaderTest, ChunkDatasetCrossChunkShuffle) {
22242224
std::equal(result.begin(), result.end(), expected_result.begin()));
22252225
}
22262226
}
2227+
}
2228+
2229+
TEST(DataLoaderTest, CustomPreprocessPolicy) {
2230+
const size_t chunk_size = 5;
2231+
const size_t batch_size = 10;
2232+
2233+
struct D : public datasets::ChunkDataReader<int> {
2234+
public:
2235+
using BatchType = datasets::ChunkDataReader<int>::ChunkType;
2236+
D(size_t chunk_count) : chunk_count_(chunk_count) {}
2237+
2238+
BatchType read_chunk(size_t chunk_index) override {
2239+
BatchType batch_data(chunk_size);
2240+
auto rand_gen = []() { return std::rand() % 100; };
2241+
std::generate(batch_data.begin(), batch_data.end(), rand_gen);
2242+
return batch_data;
2243+
}
2244+
2245+
size_t chunk_count() override {
2246+
return chunk_count_;
2247+
};
2248+
2249+
void reset() override{};
2250+
size_t chunk_count_;
2251+
};
2252+
2253+
// custom preprocessing policy - sort the data in ascending order
2254+
auto sorting_policy = [](std::vector<int>& raw_batch_data) {
2255+
std::sort(raw_batch_data.begin(), raw_batch_data.end());
2256+
};
2257+
std::function<void(std::vector<int>&)> policy_function =
2258+
sorting_policy;
2259+
2260+
const size_t prefetch_count = 1;
2261+
const size_t cache_size = 10;
2262+
const size_t cross_chunk_shuffle_counts[] = {1, 2};
2263+
const size_t chunk_counts[] = {3, 4};
2264+
2265+
samplers::SequentialSampler chunk_sampler(0);
2266+
2267+
for (auto chunk_count : chunk_counts) {
2268+
for (auto cross_chunk_shuffle_count : cross_chunk_shuffle_counts) {
2269+
D data_reader(chunk_count);
2270+
2271+
datasets::SharedBatchDataset<datasets::ChunkDataset<
2272+
D,
2273+
samplers::SequentialSampler,
2274+
samplers::SequentialSampler>>
2275+
dataset = datasets::make_shared_dataset<datasets::ChunkDataset<
2276+
D,
2277+
samplers::SequentialSampler,
2278+
samplers::SequentialSampler>>(
2279+
data_reader,
2280+
chunk_sampler,
2281+
chunk_sampler,
2282+
datasets::ChunkDatasetOptions(
2283+
prefetch_count,
2284+
batch_size,
2285+
cache_size,
2286+
cross_chunk_shuffle_count),
2287+
policy_function);
2288+
2289+
auto data_loader = torch::data::make_data_loader(
2290+
dataset, DataLoaderOptions(batch_size).workers(0));
2291+
2292+
std::vector<int> result;
2293+
for (auto iterator = data_loader->begin(); iterator != data_loader->end();
2294+
++iterator) {
2295+
auto batch_result = *iterator;
2296+
if (batch_result.size() > chunk_size * cross_chunk_shuffle_count) {
2297+
for (int i = 0; i < batch_result.size(); i += chunk_size) {
2298+
ASSERT_TRUE(std::is_sorted(
2299+
batch_result.begin() + i,
2300+
batch_result.begin() + i + chunk_size));
2301+
}
2302+
} else {
2303+
ASSERT_TRUE(std::is_sorted(batch_result.begin(), batch_result.end()));
2304+
}
2305+
}
2306+
}
2307+
}
22272308
}

torch/csrc/api/include/torch/data/datasets/chunk.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,11 +320,14 @@ class ChunkDataset final
320320
ChunkReader chunk_reader,
321321
ChunkSampler chunk_sampler,
322322
ExampleSampler example_sampler,
323-
ChunkDatasetOptions options)
323+
ChunkDatasetOptions options,
324+
std::function<void(UnwrappedBatchType&)> preprocessing_policy =
325+
std::function<void(UnwrappedBatchType&)>())
324326
: chunk_reader_(std::move(chunk_reader)),
325327
chunk_sampler_(std::move(chunk_sampler)),
326328
example_sampler_(std::move(example_sampler)),
327329
options_(std::move(options)),
330+
preprocessing_policy_(preprocessing_policy),
328331
quit_worker_(false),
329332
running_preloaders_(0),
330333
load_checkpoint_(false) {}
@@ -436,6 +439,9 @@ class ChunkDataset final
436439
std::move(
437440
chunk_data.begin(), chunk_data.end(), std::back_inserter(data));
438441
}
442+
if (preprocessing_policy_) {
443+
preprocessing_policy_(data);
444+
}
439445
if (!data.empty()) { // skip empty chunks.
440446
batch_buffer_->add_chunk_data(std::move(data));
441447
}
@@ -483,6 +489,18 @@ class ChunkDataset final
483489
/// The options the Dataset was configured with.
484490
const ChunkDatasetOptions options_;
485491

492+
// Callable wrapper used to apply custom processing to the chunk data. This is
493+
// considered an advanced parameter for developers who want to apply a
494+
// pre-processing step to the chunk data before it is sampled into minibatches.
495+
// Unlike the collate function, this policy is applied at the chunk
496+
// level instead of the minibatch level. When a chunk of data is loaded (multiple
497+
// chunks if cross_chunk_shuffle_count_ is greater than 1), this policy is
498+
// applied to the full loaded data. It is useful if developers want to
499+
// perform pre-processing (like bucketing) on the chunk data before the
500+
// example sampler samples the data. By default it is an empty function and no
501+
// action will be taken.
502+
std::function<void(UnwrappedBatchType&)> preprocessing_policy_;
503+
486504
// indicate whether the worker thread can be torn down
487505
std::atomic<bool> quit_worker_;
488506

0 commit comments

Comments
 (0)