Skip to content

Commit 3d4d7b9

Browse files
Thiago Crepaldi authored and facebook-github-bot committed
Refactor ChunkDataReader API + fix missing headers (#19485)
Summary: This PR restricts the BatchType template argument of ChunkDataReader to STL vectors only. Internally, ChunkDataReader assumed that BatchType was a vector, but the user could pass any type as the template argument, leading to compilation errors when building C++ extensions. In addition to the proposed API change, this PR adds missing include headers to chunk.h. The current implementation works, but if users try to create C++ extensions that implement new ChunkDataReaders to be used along with the existing ChunkDataset, the build will fail due to the missing headers. In terms of functionality, nothing has changed; this PR simply makes the implementation slightly more robust for future extensions. Pull Request resolved: #19485 Differential Revision: D15261725 Pulled By: soumith fbshipit-source-id: 38c9465d665392ae6a2d12c5a520a4f501e1a6ca
1 parent bed1d7d commit 3d4d7b9

File tree

2 files changed

+24
-16
lines changed

2 files changed

+24
-16
lines changed

test/cpp/api/dataloader.cpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,12 @@ TEST(DataTest, TransformCallsGetApplyCorrectly) {
6262

6363
// dummy chunk data reader with 3 chunks and 35 examples in total. Each chunk
6464
// contains 10, 5, 20 examples respectively.
65+
6566
struct DummyChunkDataReader
66-
: public datasets::ChunkDataReader<std::vector<int>> {
67+
: public datasets::ChunkDataReader<int> {
6768
public:
68-
using BatchType = std::vector<int>;
69+
using BatchType = datasets::ChunkDataReader<int>::ChunkType;
70+
using DataType = datasets::ChunkDataReader<int>::ExampleType;
6971

7072
/// Read an entire chunk.
7173
BatchType read_chunk(size_t chunk_index) override {
@@ -1650,7 +1652,7 @@ TEST(DataLoaderTest, ChunkDataSetGetBatch) {
16501652
for (auto iterator = data_loader->begin();
16511653
iterator != data_loader->end();
16521654
++iterator, ++iteration_count) {
1653-
std::vector<int>& batch = *iterator;
1655+
DummyChunkDataReader::BatchType& batch = *iterator;
16541656
ASSERT_EQ(batch.size(), batch_size);
16551657

16561658
// When prefetch_count is equal to 1 and no worker thread, the batch
@@ -1709,9 +1711,9 @@ TEST(DataLoaderTest, ChunkDataSetWithBatchSizeMismatch) {
17091711

17101712
TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) {
17111713
struct DummyEmptyChunkDataReader
1712-
: datasets::ChunkDataReader<std::vector<int>> {
1714+
: datasets::ChunkDataReader<int> {
17131715
public:
1714-
using BatchType = std::vector<int>;
1716+
using BatchType = datasets::ChunkDataReader<int>::ChunkType;
17151717

17161718
BatchType read_chunk(size_t chunk_index) override {
17171719
return {};
@@ -1752,9 +1754,9 @@ TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) {
17521754
}
17531755

17541756
TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
1755-
struct D : public datasets::ChunkDataReader<std::vector<int>> {
1757+
struct D : public datasets::ChunkDataReader<int> {
17561758
public:
1757-
using BatchType = std::vector<int>;
1759+
using BatchType = datasets::ChunkDataReader<int>::ChunkType;
17581760

17591761
BatchType read_chunk(size_t chunk_index) override {
17601762
BatchType batch_data(10, 0);
@@ -1791,7 +1793,7 @@ TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
17911793

17921794
for (auto iterator = data_loader->begin(); iterator != data_loader->end();
17931795
++iterator) {
1794-
std::vector<int> batch = *iterator;
1796+
DummyChunkDataReader::BatchType batch = *iterator;
17951797
auto batch_size = batch.size();
17961798
if (batch_size == 17) {
17971799
ASSERT_TRUE(batch.size() == 17 || batch.size() == 3);
@@ -1825,8 +1827,8 @@ TEST(DataLoaderTest, CanAccessChunkSamplerWithChunkDataSet) {
18251827
samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();
18261828

18271829
auto data_loader = torch::data::make_data_loader(
1828-
dataset.map(transforms::BatchLambda<std::vector<int>, int>(
1829-
[](std::vector<int> batch) {
1830+
dataset.map(transforms::BatchLambda<DummyChunkDataReader::BatchType, DummyChunkDataReader::DataType>(
1831+
[](DummyChunkDataReader::BatchType batch) {
18301832
return std::accumulate(batch.begin(), batch.end(), 0);
18311833
})),
18321834
DataLoaderOptions(batch_size).workers(0));
@@ -1869,13 +1871,13 @@ TEST(DataLoaderTest, ChunkDatasetDoesNotHang) {
18691871
samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();
18701872

18711873
auto data_loader = torch::data::make_data_loader(
1872-
dataset.map(transforms::BatchLambda<std::vector<int>, int>(
1873-
[](std::vector<int> batch) {
1874+
dataset.map(transforms::BatchLambda<DummyChunkDataReader::BatchType, DummyChunkDataReader::DataType>(
1875+
[](DummyChunkDataReader::BatchType batch) {
18741876
return std::accumulate(batch.begin(), batch.end(), 0);
18751877
})),
18761878
DataLoaderOptions(batch_size).workers(0));
18771879
// simply creates the iterator but no iteration. chunk preloaders are waiting
18781880
// to fill the batch buffer but it is not draining. Still we need to exit
18791881
// cleanly.
18801882
auto iterator = data_loader->begin();
1881-
}
1883+
}

torch/csrc/api/include/torch/data/datasets/chunk.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
#pragma once
22

3+
#include <torch/arg.h>
4+
#include <torch/csrc/utils/memory.h>
35
#include <torch/data/datasets/stateful.h>
6+
#include <torch/data/samplers.h>
7+
#include <queue>
8+
#include <thread>
49

510
namespace torch {
611
namespace data {
@@ -12,10 +17,11 @@ namespace datasets {
1217
/// A chunk could be an entire file, such as an audio data file or an image,
1318
/// or part of a file in the case of a large text-file split based on seek
1419
/// positions.
15-
template <typename Chunk = std::vector<Example<>>>
20+
template <typename ExampleType_, typename ChunkType_ = std::vector<ExampleType_>>
1621
class ChunkDataReader {
1722
public:
18-
using ChunkType = Chunk;
23+
using ChunkType = ChunkType_;
24+
using ExampleType = ExampleType_;
1925

2026
/// Read an entire chunk.
2127
virtual ChunkType read_chunk(size_t chunk_index) = 0;
@@ -34,7 +40,7 @@ namespace detail {
3440
/// return. If the cache is empty, it either waits to load more chunks or return
3541
/// null if all chunks are loaded.
3642
template <
37-
typename UnwrappedBatch = std::vector<Example<>>,
43+
typename UnwrappedBatch,
3844
typename ExampleSampler = samplers::RandomSampler>
3945
class BatchDataBuffer {
4046
public:

0 commit comments

Comments
 (0)