Skip to content

Commit 32b6986

Browse files
committed
Add support for multiple datasets in the trainer. The datasets are interleaved at binpack chunk granularity.
Handling cyclic binpack reading was moved from `BinpackSfenInputParallelStream` to `CompressedTrainingDataEntryParallelReader` to allow cycling each dataset individually. train.py accepts any number of positional arguments — paths to the datasets to use. Validation will use the same datasets unless overridden by `--validation-data`, which can be present multiple times (or have multiple values) to specify multiple datasets. easy_train.py can now have multiple instances of `--training-dataset` and `--validation-dataset` (or accept multiple values for each). The C API changed: additional helper functions were added to wrap the conversion of a string list to an array of char* for `create_fen_batch_stream` and `create_sparse_batch_stream`. This also effectively deprecates .bin, which currently won't work with multiple datasets. No one was using it anyway; we should remove it (or at least disallow it in the trainer) next time.
1 parent 2aa01ec commit 32b6986

File tree

7 files changed

+181
-124
lines changed

7 files changed

+181
-124
lines changed

lib/nnue_training_data_formats.h

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6726,6 +6726,12 @@ namespace binpack
67266726
}
67276727
}
67286728

6729+
inline std::ifstream::pos_type filesize(const char* filename)
6730+
{
6731+
std::ifstream in(filename, std::ifstream::ate | std::ifstream::binary);
6732+
return in.tellg();
6733+
}
6734+
67296735
struct CompressedTrainingDataFile
67306736
{
67316737
struct Header
@@ -6737,12 +6743,15 @@ namespace binpack
67376743
m_path(std::move(path)),
67386744
m_file(m_path, std::ios_base::binary | std::ios_base::in | std::ios_base::out | om)
67396745
{
6746+
// Racey but who cares
6747+
m_sizeBytes = filesize(m_path.c_str());
67406748
}
67416749

67426750
void append(const char* data, std::uint32_t size)
67436751
{
67446752
writeChunkHeader({size});
67456753
m_file.write(data, size);
6754+
m_sizeBytes += size + 8;
67466755
}
67476756

67486757
[[nodiscard]] bool hasNextChunk()
@@ -6756,6 +6765,11 @@ namespace binpack
67566765
return !m_file.eof();
67576766
}
67586767

6768+
void seek_to_start()
6769+
{
6770+
m_file.seekg(0);
6771+
}
6772+
67596773
[[nodiscard]] std::vector<unsigned char> readNextChunk()
67606774
{
67616775
auto size = readChunkHeader().chunkSize;
@@ -6764,9 +6778,15 @@ namespace binpack
67646778
return data;
67656779
}
67666780

6781+
[[nodiscard]] std::size_t sizeBytes() const
6782+
{
6783+
return m_sizeBytes;
6784+
}
6785+
67676786
private:
67686787
std::string m_path;
67696788
std::fstream m_file;
6789+
std::size_t m_sizeBytes;
67706790

67716791
void writeChunkHeader(Header h)
67726792
{
@@ -7558,21 +7578,32 @@ namespace binpack
75587578

75597579
CompressedTrainingDataEntryParallelReader(
75607580
int concurrency,
7561-
std::string path,
7581+
std::vector<std::string> paths,
75627582
std::ios_base::openmode om = std::ios_base::app,
7583+
bool cyclic = false,
75637584
std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr
75647585
) :
75657586
m_concurrency(concurrency),
7566-
m_inputFile(path, om),
75677587
m_bufferOffset(0),
7588+
m_cyclic(cyclic),
75687589
m_skipPredicate(std::move(skipPredicate))
75697590
{
75707591
m_numRunningWorkers.store(0);
7571-
if (!m_inputFile.hasNextChunk())
7592+
std::vector<double> sizes; // discrete distribution wants double weights
7593+
for (const auto& path : paths)
75727594
{
7573-
return;
7595+
auto& file = m_inputFiles.emplace_back(path, om);
7596+
7597+
if (!file.hasNextChunk())
7598+
{
7599+
return;
7600+
}
7601+
7602+
sizes.emplace_back(static_cast<double>(file.sizeBytes()));
75747603
}
75757604

7605+
m_inputFileDistribution = std::discrete_distribution<>(sizes.begin(), sizes.end());
7606+
75767607
m_stopFlag.store(false);
75777608

75787609
auto worker = [this]()
@@ -7742,8 +7773,10 @@ namespace binpack
77427773

77437774
private:
77447775
int m_concurrency;
7745-
CompressedTrainingDataFile m_inputFile;
7776+
std::vector<CompressedTrainingDataFile> m_inputFiles;
7777+
std::discrete_distribution<> m_inputFileDistribution;
77467778
std::atomic_int m_numRunningWorkers;
7779+
bool m_cyclic;
77477780

77487781
static constexpr int threadBufferSize = 256 * 256 * 16;
77497782

@@ -7763,17 +7796,24 @@ namespace binpack
77637796
{
77647797
if (m_offset + sizeof(PackedTrainingDataEntry) + 2 > m_chunk.size())
77657798
{
7799+
auto& prng = rng::get_thread_local_rng();
7800+
const std::size_t fileId = m_inputFileDistribution(prng);
7801+
auto& inputFile = m_inputFiles[fileId];
7802+
77667803
std::unique_lock lock(m_fileMutex);
77677804

7768-
if (!m_inputFile.hasNextChunk())
7769-
{
7770-
return true;
7771-
}
7772-
else
7805+
if (!inputFile.hasNextChunk())
77737806
{
7774-
m_chunk = m_inputFile.readNextChunk();
7775-
m_offset = 0;
7807+
if (m_cyclic)
7808+
{
7809+
inputFile.seek_to_start();
7810+
}
7811+
else
7812+
return true;
77767813
}
7814+
7815+
m_chunk = inputFile.readNextChunk();
7816+
m_offset = 0;
77777817
}
77787818

77797819
return false;

lib/nnue_training_data_stream.h

Lines changed: 10 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,9 @@ namespace training_data {
183183
static constexpr auto openmode = std::ios::in | std::ios::binary;
184184
static inline const std::string extension = "binpack";
185185

186-
BinpackSfenInputParallelStream(int concurrency, std::string filename, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate) :
187-
m_stream(std::make_unique<binpack::CompressedTrainingDataEntryParallelReader>(concurrency, filename, openmode, skipPredicate)),
188-
m_filename(filename),
186+
BinpackSfenInputParallelStream(int concurrency, const std::vector<std::string>& filenames, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate) :
187+
m_stream(std::make_unique<binpack::CompressedTrainingDataEntryParallelReader>(concurrency, filenames, openmode, cyclic, skipPredicate)),
188+
m_filenames(filenames),
189189
m_concurrency(concurrency),
190190
m_eof(false),
191191
m_cyclic(cyclic),
@@ -199,12 +199,6 @@ namespace training_data {
199199
auto v = m_stream->next();
200200
if (!v.has_value())
201201
{
202-
if (m_cyclic)
203-
{
204-
m_stream = std::make_unique<binpack::CompressedTrainingDataEntryParallelReader>(m_concurrency, m_filename, openmode, m_skipPredicate);
205-
return m_stream->next();
206-
}
207-
208202
m_eof = true;
209203
return std::nullopt;
210204
}
@@ -217,32 +211,7 @@ namespace training_data {
217211
auto k = m_stream->fill(v, n);
218212
if (n != k)
219213
{
220-
if (m_cyclic)
221-
{
222-
m_stream = std::make_unique<binpack::CompressedTrainingDataEntryParallelReader>(m_concurrency, m_filename, openmode, m_skipPredicate);
223-
n -= k;
224-
k = m_stream->fill(v, n);
225-
if (k == 0)
226-
{
227-
// No data in the file
228-
m_eof = true;
229-
return;
230-
}
231-
else if (k == n)
232-
{
233-
// We're done
234-
return;
235-
}
236-
else
237-
{
238-
// We need to read again
239-
this->fill(v, n - k);
240-
}
241-
}
242-
else
243-
{
244-
m_eof = true;
245-
}
214+
m_eof = true;
246215
}
247216
}
248217

@@ -255,7 +224,7 @@ namespace training_data {
255224

256225
private:
257226
std::unique_ptr<binpack::CompressedTrainingDataEntryParallelReader> m_stream;
258-
std::string m_filename;
227+
std::vector<std::string> m_filenames;
259228
int m_concurrency;
260229
bool m_eof;
261230
bool m_cyclic;
@@ -272,13 +241,13 @@ namespace training_data {
272241
return nullptr;
273242
}
274243

275-
inline std::unique_ptr<BasicSfenInputStream> open_sfen_input_file_parallel(int concurrency, const std::string& filename, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr)
244+
inline std::unique_ptr<BasicSfenInputStream> open_sfen_input_file_parallel(int concurrency, const std::vector<std::string>& filenames, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr)
276245
{
277246
// TODO (low priority): optimize and parallelize .bin reading.
278-
if (has_extension(filename, BinSfenInputStream::extension))
279-
return std::make_unique<BinSfenInputStream>(filename, cyclic, std::move(skipPredicate));
280-
else if (has_extension(filename, BinpackSfenInputParallelStream::extension))
281-
return std::make_unique<BinpackSfenInputParallelStream>(concurrency, filename, cyclic, std::move(skipPredicate));
247+
if (has_extension(filenames[0], BinSfenInputStream::extension))
248+
return std::make_unique<BinSfenInputStream>(filenames[0], cyclic, std::move(skipPredicate));
249+
else if (has_extension(filenames[0], BinpackSfenInputParallelStream::extension))
250+
return std::make_unique<BinpackSfenInputParallelStream>(concurrency, filenames, cyclic, std::move(skipPredicate));
282251

283252
return nullptr;
284253
}

nnue_dataset.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,18 @@ def get_fens(self):
6767
return strings
6868

6969
FenBatchPtr = ctypes.POINTER(FenBatch)
70-
# EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, const char* filename, int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered, int param_index)
70+
# EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, int num_files, const char* const* filenames, int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered, int early_fen_skipping, int param_index)
7171
create_fen_batch_stream = dll.create_fen_batch_stream
7272
create_fen_batch_stream.restype = ctypes.c_void_p
73-
create_fen_batch_stream.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool, ctypes.c_int, ctypes.c_int]
73+
create_fen_batch_stream.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool, ctypes.c_int, ctypes.c_int]
7474
destroy_fen_batch_stream = dll.destroy_fen_batch_stream
7575
destroy_fen_batch_stream.argtypes = [ctypes.c_void_p]
7676

77+
def make_fen_batch_stream(concurrency, filenames, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index):
78+
filenames_ = (ctypes.c_char_p * len(filenames))()
79+
filenames_[:] = [filename.encode('utf-8') for filename in filenames]
80+
return create_fen_batch_stream(concurrency, len(filenames), filenames_, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
81+
7782
fetch_next_fen_batch = dll.fetch_next_fen_batch
7883
fetch_next_fen_batch.restype = FenBatchPtr
7984
fetch_next_fen_batch.argtypes = [ctypes.c_void_p]
@@ -103,9 +108,9 @@ def __init__(
103108
self.param_index = param_index
104109

105110
if batch_size:
106-
self.stream = create_fen_batch_stream(self.num_workers, self.filename, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
111+
self.stream = make_fen_batch_stream(self.num_workers, [self.filename], batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
107112
else:
108-
self.stream = create_fen_batch_stream(self.num_workers, self.filename, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
113+
self.stream = make_fen_batch_stream(self.num_workers, [self.filename], cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
109114

110115
def __iter__(self):
111116
return self
@@ -131,7 +136,7 @@ def __init__(
131136
destroy_stream,
132137
fetch_next,
133138
destroy_part,
134-
filename,
139+
filenames,
135140
cyclic,
136141
num_workers,
137142
batch_size=None,
@@ -147,7 +152,7 @@ def __init__(
147152
self.destroy_stream = destroy_stream
148153
self.fetch_next = fetch_next
149154
self.destroy_part = destroy_part
150-
self.filename = filename.encode('utf-8')
155+
self.filenames = filenames
151156
self.cyclic = cyclic
152157
self.num_workers = num_workers
153158
self.batch_size = batch_size
@@ -158,9 +163,9 @@ def __init__(
158163
self.device = device
159164

160165
if batch_size:
161-
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filename, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
166+
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filenames, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
162167
else:
163-
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filename, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
168+
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filenames, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
164169

165170
def __iter__(self):
166171
return self
@@ -178,14 +183,19 @@ def __next__(self):
178183
def __del__(self):
179184
self.destroy_stream(self.stream)
180185

181-
# EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, const char* filename, int batch_size, bool cyclic,
186+
# EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, int num_files, const char* const* filenames, int batch_size, bool cyclic,
182187
# bool filtered, int random_fen_skipping, bool wld_filtered, int early_fen_skipping, int param_index)
183188
create_sparse_batch_stream = dll.create_sparse_batch_stream
184189
create_sparse_batch_stream.restype = ctypes.c_void_p
185-
create_sparse_batch_stream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool, ctypes.c_int, ctypes.c_int]
190+
create_sparse_batch_stream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool, ctypes.c_int, ctypes.c_int]
186191
destroy_sparse_batch_stream = dll.destroy_sparse_batch_stream
187192
destroy_sparse_batch_stream.argtypes = [ctypes.c_void_p]
188193

194+
def make_sparse_batch_stream(feature_set, concurrency, filenames, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index):
195+
filenames_ = (ctypes.c_char_p * len(filenames))()
196+
filenames_[:] = [filename.encode('utf-8') for filename in filenames]
197+
return create_sparse_batch_stream(feature_set, concurrency, len(filenames), filenames_, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
198+
189199
fetch_next_sparse_batch = dll.fetch_next_sparse_batch
190200
fetch_next_sparse_batch.restype = SparseBatchPtr
191201
fetch_next_sparse_batch.argtypes = [ctypes.c_void_p]
@@ -211,14 +221,14 @@ def make_sparse_batch_from_fens(feature_set, fens, scores, plies, results):
211221
return b
212222

213223
class SparseBatchProvider(TrainingDataProvider):
214-
def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, early_fen_skipping=-1, param_index=0, device='cpu'):
224+
def __init__(self, feature_set, filenames, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, early_fen_skipping=-1, param_index=0, device='cpu'):
215225
super(SparseBatchProvider, self).__init__(
216226
feature_set,
217-
create_sparse_batch_stream,
227+
make_sparse_batch_stream,
218228
destroy_sparse_batch_stream,
219229
fetch_next_sparse_batch,
220230
destroy_sparse_batch,
221-
filename,
231+
filenames,
222232
cyclic,
223233
num_workers,
224234
batch_size,
@@ -230,10 +240,10 @@ def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1
230240
device)
231241

232242
class SparseBatchDataset(torch.utils.data.IterableDataset):
233-
def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, early_fen_skipping=-1, param_index=0, device='cpu'):
243+
def __init__(self, feature_set, filenames, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, early_fen_skipping=-1, param_index=0, device='cpu'):
234244
super(SparseBatchDataset).__init__()
235245
self.feature_set = feature_set
236-
self.filename = filename
246+
self.filenames = filenames
237247
self.batch_size = batch_size
238248
self.cyclic = cyclic
239249
self.num_workers = num_workers
@@ -245,7 +255,7 @@ def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1
245255
self.device = device
246256

247257
def __iter__(self):
248-
return SparseBatchProvider(self.feature_set, self.filename, self.batch_size, cyclic=self.cyclic, num_workers=self.num_workers,
258+
return SparseBatchProvider(self.feature_set, self.filenames, self.batch_size, cyclic=self.cyclic, num_workers=self.num_workers,
249259
filtered=self.filtered, random_fen_skipping=self.random_fen_skipping, wld_filtered=self.wld_filtered, early_fen_skipping = self.early_fen_skipping, param_index=self.param_index, device=self.device)
250260

251261
class FixedNumBatchesDataset(Dataset):

0 commit comments

Comments
 (0)