Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions test/cpp/api/dataloader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ TEST(DataTest, TransformCallsGetApplyCorrectly) {

// Dummy chunk data reader with 3 chunks and 35 examples in total. The chunks
// contain 10, 5, and 20 examples respectively.

struct DummyChunkDataReader
: public datasets::ChunkDataReader<std::vector<int>> {
: public datasets::ChunkDataReader<int> {
public:
using BatchType = std::vector<int>;
using BatchType = datasets::ChunkDataReader<int>::ChunkType;
using DataType = datasets::ChunkDataReader<int>::ExampleType;

/// Read an entire chunk.
BatchType read_chunk(size_t chunk_index) override {
Expand Down Expand Up @@ -1650,7 +1652,7 @@ TEST(DataLoaderTest, ChunkDataSetGetBatch) {
for (auto iterator = data_loader->begin();
iterator != data_loader->end();
++iterator, ++iteration_count) {
std::vector<int>& batch = *iterator;
DummyChunkDataReader::BatchType& batch = *iterator;
ASSERT_EQ(batch.size(), batch_size);

// When prefetch_count is equal to 1 and no worker thread, the batch
Expand Down Expand Up @@ -1709,9 +1711,9 @@ TEST(DataLoaderTest, ChunkDataSetWithBatchSizeMismatch) {

TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) {
struct DummyEmptyChunkDataReader
: datasets::ChunkDataReader<std::vector<int>> {
: datasets::ChunkDataReader<int> {
public:
using BatchType = std::vector<int>;
using BatchType = datasets::ChunkDataReader<int>::ChunkType;

BatchType read_chunk(size_t chunk_index) override {
return {};
Expand Down Expand Up @@ -1752,9 +1754,9 @@ TEST(DataLoaderTest, ChunkDataSetWithEmptyBatch) {
}

TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {
struct D : public datasets::ChunkDataReader<std::vector<int>> {
struct D : public datasets::ChunkDataReader<int> {
public:
using BatchType = std::vector<int>;
using BatchType = datasets::ChunkDataReader<int>::ChunkType;

BatchType read_chunk(size_t chunk_index) override {
BatchType batch_data(10, 0);
Expand Down Expand Up @@ -1791,7 +1793,7 @@ TEST(DataLoaderTest, ChunkDataSetGetBatchWithUnevenBatchSize) {

for (auto iterator = data_loader->begin(); iterator != data_loader->end();
++iterator) {
std::vector<int> batch = *iterator;
DummyChunkDataReader::BatchType batch = *iterator;
auto batch_size = batch.size();
if (batch_size == 17) {
ASSERT_TRUE(batch.size() == 17 || batch.size() == 3);
Expand Down Expand Up @@ -1825,8 +1827,8 @@ TEST(DataLoaderTest, CanAccessChunkSamplerWithChunkDataSet) {
samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();

auto data_loader = torch::data::make_data_loader(
dataset.map(transforms::BatchLambda<std::vector<int>, int>(
[](std::vector<int> batch) {
dataset.map(transforms::BatchLambda<DummyChunkDataReader::BatchType, DummyChunkDataReader::DataType>(
[](DummyChunkDataReader::BatchType batch) {
return std::accumulate(batch.begin(), batch.end(), 0);
})),
DataLoaderOptions(batch_size).workers(0));
Expand Down Expand Up @@ -1869,13 +1871,13 @@ TEST(DataLoaderTest, ChunkDatasetDoesNotHang) {
samplers::SequentialSampler& chunk_sampler = dataset->chunk_sampler();

auto data_loader = torch::data::make_data_loader(
dataset.map(transforms::BatchLambda<std::vector<int>, int>(
[](std::vector<int> batch) {
dataset.map(transforms::BatchLambda<DummyChunkDataReader::BatchType, DummyChunkDataReader::DataType>(
[](DummyChunkDataReader::BatchType batch) {
return std::accumulate(batch.begin(), batch.end(), 0);
})),
DataLoaderOptions(batch_size).workers(0));
// Simply create the iterator without iterating over it. The chunk preloaders
// are waiting to fill the batch buffer, but it is not being drained. We still
// need to exit cleanly.
auto iterator = data_loader->begin();
}
}
12 changes: 9 additions & 3 deletions torch/csrc/api/include/torch/data/datasets/chunk.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
#pragma once

#include <torch/arg.h>
#include <torch/csrc/utils/memory.h>
#include <torch/data/datasets/stateful.h>
#include <torch/data/samplers.h>
#include <queue>
#include <thread>

namespace torch {
namespace data {
Expand All @@ -12,10 +17,11 @@ namespace datasets {
/// A chunk could be an entire file, such as an audio data file or an image,
/// or part of a file in the case of a large text-file split based on seek
/// positions.
template <typename Chunk = std::vector<Example<>>>
template <typename ExampleType_, typename ChunkType_ = std::vector<ExampleType_>>
class ChunkDataReader {
public:
using ChunkType = Chunk;
using ChunkType = ChunkType_;
using ExampleType = ExampleType_;

/// Read an entire chunk.
virtual ChunkType read_chunk(size_t chunk_index) = 0;
Expand All @@ -34,7 +40,7 @@ namespace detail {
/// return. If the cache is empty, it either waits to load more chunks or returns
/// null if all chunks are loaded.
template <
typename UnwrappedBatch = std::vector<Example<>>,
typename UnwrappedBatch,
typename ExampleSampler = samplers::RandomSampler>
class BatchDataBuffer {
public:
Expand Down