Skip to content

Commit 07fea3f

Browse files
Thiago Crepaldifacebook-github-bot
authored andcommitted
Add new get_batch() method to ChunkDataset API (#21797)
Summary: We plan on generating python bindings for C++ ChunkDataset API using the current Pytorch Dataloader class, which must call get_batch() instead of get_batch(size) This changes doesnt break the current API, just add one more method that will make future extensions easier (WIP) Pull Request resolved: #21797 Differential Revision: D15830522 Pulled By: soumith fbshipit-source-id: 7208f305b48bf65d2783eaff43ff57a05e62c255
1 parent dddc65d commit 07fea3f

File tree

1 file changed

+5
-0
lines changed
  • torch/csrc/api/include/torch/data/datasets

1 file changed

+5
-0
lines changed

torch/csrc/api/include/torch/data/datasets/chunk.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,11 @@ class ChunkDataset final
336336
return batch_buffer_->get_batch();
337337
}
338338

339+
/// Helper method around get_batch as `batch_size` is not strictly necessary
340+
BatchType get_batch() {
341+
return get_batch(options_.batch_size_);
342+
}
343+
339344
/// This will clear any internal state and starts the internal prefetching
340345
/// mechanism for the chunk dataset.
341346
void reset() override {

0 commit comments

Comments
 (0)