Skip to content

Commit cdfa110

Browse files
committed
Add support for multiple datasets in the trainer. The datasets are interleaved at binpack chunk granularity.
Handling of cyclic binpack reading was moved from `BinpackSfenInputParallelStream` to `CompressedTrainingDataEntryParallelReader` to allow cycling each dataset individually. train.py accepts any number of positional arguments — paths to the datasets to use. Validation will use the same datasets unless overridden by `--validation-data`, which can be present multiple times (or have multiple values) to specify multiple datasets. easy_train.py can now have multiple instances of `--training-dataset` and `--validation-dataset` (or accept multiple values for each). The C API changed; additional helper functions were added to wrap the conversion of a string list to an array of char* for `create_fen_batch_stream` and `create_sparse_batch_stream`. This also effectively deprecates .bin, which currently won't work with multiple datasets. No one was using it anyway; we should remove it (or at least disallow it in the trainer) next time.
1 parent 2aa01ec commit cdfa110

File tree

6 files changed

+168
-123
lines changed

6 files changed

+168
-123
lines changed

lib/nnue_training_data_formats.h

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
4747
#include <thread>
4848
#include <mutex>
4949
#include <random>
50+
#include <filesystem>
5051

5152
#include "rng.h"
5253

@@ -6737,12 +6738,15 @@ namespace binpack
67376738
m_path(std::move(path)),
67386739
m_file(m_path, std::ios_base::binary | std::ios_base::in | std::ios_base::out | om)
67396740
{
6741+
// Racey but who cares
6742+
m_sizeBytes = std::filesystem::file_size(m_path);
67406743
}
67416744

67426745
void append(const char* data, std::uint32_t size)
67436746
{
67446747
writeChunkHeader({size});
67456748
m_file.write(data, size);
6749+
m_sizeBytes += size + 8;
67466750
}
67476751

67486752
[[nodiscard]] bool hasNextChunk()
@@ -6756,6 +6760,11 @@ namespace binpack
67566760
return !m_file.eof();
67576761
}
67586762

6763+
void seek_to_start()
6764+
{
6765+
m_file.seekg(0);
6766+
}
6767+
67596768
[[nodiscard]] std::vector<unsigned char> readNextChunk()
67606769
{
67616770
auto size = readChunkHeader().chunkSize;
@@ -6764,9 +6773,15 @@ namespace binpack
67646773
return data;
67656774
}
67666775

6776+
[[nodiscard]] std::size_t sizeBytes() const
6777+
{
6778+
return m_sizeBytes;
6779+
}
6780+
67676781
private:
67686782
std::string m_path;
67696783
std::fstream m_file;
6784+
std::size_t m_sizeBytes;
67706785

67716786
void writeChunkHeader(Header h)
67726787
{
@@ -7558,21 +7573,32 @@ namespace binpack
75587573

75597574
CompressedTrainingDataEntryParallelReader(
75607575
int concurrency,
7561-
std::string path,
7576+
std::vector<std::string> paths,
75627577
std::ios_base::openmode om = std::ios_base::app,
7578+
bool cyclic = false,
75637579
std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr
75647580
) :
75657581
m_concurrency(concurrency),
7566-
m_inputFile(path, om),
75677582
m_bufferOffset(0),
7583+
m_cyclic(cyclic),
75687584
m_skipPredicate(std::move(skipPredicate))
75697585
{
75707586
m_numRunningWorkers.store(0);
7571-
if (!m_inputFile.hasNextChunk())
7587+
std::vector<double> sizes; // discrete distribution wants double weights
7588+
for (const auto& path : paths)
75727589
{
7573-
return;
7590+
auto& file = m_inputFiles.emplace_back(path, om);
7591+
7592+
if (!file.hasNextChunk())
7593+
{
7594+
return;
7595+
}
7596+
7597+
sizes.emplace_back(static_cast<double>(file.sizeBytes()));
75747598
}
75757599

7600+
m_inputFileDistribution = std::discrete_distribution<>(sizes.begin(), sizes.end());
7601+
75767602
m_stopFlag.store(false);
75777603

75787604
auto worker = [this]()
@@ -7742,8 +7768,10 @@ namespace binpack
77427768

77437769
private:
77447770
int m_concurrency;
7745-
CompressedTrainingDataFile m_inputFile;
7771+
std::vector<CompressedTrainingDataFile> m_inputFiles;
7772+
std::discrete_distribution<> m_inputFileDistribution;
77467773
std::atomic_int m_numRunningWorkers;
7774+
bool m_cyclic;
77477775

77487776
static constexpr int threadBufferSize = 256 * 256 * 16;
77497777

@@ -7763,17 +7791,24 @@ namespace binpack
77637791
{
77647792
if (m_offset + sizeof(PackedTrainingDataEntry) + 2 > m_chunk.size())
77657793
{
7794+
auto& prng = rng::get_thread_local_rng();
7795+
const std::size_t fileId = m_inputFileDistribution(prng);
7796+
auto& inputFile = m_inputFiles[fileId];
7797+
77667798
std::unique_lock lock(m_fileMutex);
77677799

7768-
if (!m_inputFile.hasNextChunk())
7769-
{
7770-
return true;
7771-
}
7772-
else
7800+
if (!inputFile.hasNextChunk())
77737801
{
7774-
m_chunk = m_inputFile.readNextChunk();
7775-
m_offset = 0;
7802+
if (m_cyclic)
7803+
{
7804+
inputFile.seek_to_start();
7805+
}
7806+
else
7807+
return true;
77767808
}
7809+
7810+
m_chunk = inputFile.readNextChunk();
7811+
m_offset = 0;
77777812
}
77787813

77797814
return false;

lib/nnue_training_data_stream.h

Lines changed: 10 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,9 @@ namespace training_data {
183183
static constexpr auto openmode = std::ios::in | std::ios::binary;
184184
static inline const std::string extension = "binpack";
185185

186-
BinpackSfenInputParallelStream(int concurrency, std::string filename, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate) :
187-
m_stream(std::make_unique<binpack::CompressedTrainingDataEntryParallelReader>(concurrency, filename, openmode, skipPredicate)),
188-
m_filename(filename),
186+
BinpackSfenInputParallelStream(int concurrency, const std::vector<std::string>& filenames, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate) :
187+
m_stream(std::make_unique<binpack::CompressedTrainingDataEntryParallelReader>(concurrency, filenames, openmode, cyclic, skipPredicate)),
188+
m_filenames(filenames),
189189
m_concurrency(concurrency),
190190
m_eof(false),
191191
m_cyclic(cyclic),
@@ -199,12 +199,6 @@ namespace training_data {
199199
auto v = m_stream->next();
200200
if (!v.has_value())
201201
{
202-
if (m_cyclic)
203-
{
204-
m_stream = std::make_unique<binpack::CompressedTrainingDataEntryParallelReader>(m_concurrency, m_filename, openmode, m_skipPredicate);
205-
return m_stream->next();
206-
}
207-
208202
m_eof = true;
209203
return std::nullopt;
210204
}
@@ -217,32 +211,7 @@ namespace training_data {
217211
auto k = m_stream->fill(v, n);
218212
if (n != k)
219213
{
220-
if (m_cyclic)
221-
{
222-
m_stream = std::make_unique<binpack::CompressedTrainingDataEntryParallelReader>(m_concurrency, m_filename, openmode, m_skipPredicate);
223-
n -= k;
224-
k = m_stream->fill(v, n);
225-
if (k == 0)
226-
{
227-
// No data in the file
228-
m_eof = true;
229-
return;
230-
}
231-
else if (k == n)
232-
{
233-
// We're done
234-
return;
235-
}
236-
else
237-
{
238-
// We need to read again
239-
this->fill(v, n - k);
240-
}
241-
}
242-
else
243-
{
244-
m_eof = true;
245-
}
214+
m_eof = true;
246215
}
247216
}
248217

@@ -255,7 +224,7 @@ namespace training_data {
255224

256225
private:
257226
std::unique_ptr<binpack::CompressedTrainingDataEntryParallelReader> m_stream;
258-
std::string m_filename;
227+
std::vector<std::string> m_filenames;
259228
int m_concurrency;
260229
bool m_eof;
261230
bool m_cyclic;
@@ -272,13 +241,13 @@ namespace training_data {
272241
return nullptr;
273242
}
274243

275-
inline std::unique_ptr<BasicSfenInputStream> open_sfen_input_file_parallel(int concurrency, const std::string& filename, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr)
244+
inline std::unique_ptr<BasicSfenInputStream> open_sfen_input_file_parallel(int concurrency, const std::vector<std::string>& filenames, bool cyclic, std::function<bool(const TrainingDataEntry&)> skipPredicate = nullptr)
276245
{
277246
// TODO (low priority): optimize and parallelize .bin reading.
278-
if (has_extension(filename, BinSfenInputStream::extension))
279-
return std::make_unique<BinSfenInputStream>(filename, cyclic, std::move(skipPredicate));
280-
else if (has_extension(filename, BinpackSfenInputParallelStream::extension))
281-
return std::make_unique<BinpackSfenInputParallelStream>(concurrency, filename, cyclic, std::move(skipPredicate));
247+
if (has_extension(filenames[0], BinSfenInputStream::extension))
248+
return std::make_unique<BinSfenInputStream>(filenames[0], cyclic, std::move(skipPredicate));
249+
else if (has_extension(filenames[0], BinpackSfenInputParallelStream::extension))
250+
return std::make_unique<BinpackSfenInputParallelStream>(concurrency, filenames, cyclic, std::move(skipPredicate));
282251

283252
return nullptr;
284253
}

nnue_dataset.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,18 @@ def get_fens(self):
6767
return strings
6868

6969
FenBatchPtr = ctypes.POINTER(FenBatch)
70-
# EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, const char* filename, int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered, int param_index)
70+
# EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, int num_files, const char* const* filenames, int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered, int early_fen_skipping, int param_index)
7171
create_fen_batch_stream = dll.create_fen_batch_stream
7272
create_fen_batch_stream.restype = ctypes.c_void_p
73-
create_fen_batch_stream.argtypes = [ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool, ctypes.c_int, ctypes.c_int]
73+
create_fen_batch_stream.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool, ctypes.c_int, ctypes.c_int]
7474
destroy_fen_batch_stream = dll.destroy_fen_batch_stream
7575
destroy_fen_batch_stream.argtypes = [ctypes.c_void_p]
7676

77+
def make_fen_batch_stream(concurrency, filenames, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index):
78+
filenames_ = (ctypes.c_char_p * len(filenames))()
79+
filenames_[:] = [filename.encode('utf-8') for filename in filenames]
80+
return create_fen_batch_stream(concurrency, len(filenames), filenames_, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
81+
7782
fetch_next_fen_batch = dll.fetch_next_fen_batch
7883
fetch_next_fen_batch.restype = FenBatchPtr
7984
fetch_next_fen_batch.argtypes = [ctypes.c_void_p]
@@ -103,9 +108,9 @@ def __init__(
103108
self.param_index = param_index
104109

105110
if batch_size:
106-
self.stream = create_fen_batch_stream(self.num_workers, self.filename, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
111+
self.stream = make_fen_batch_stream(self.num_workers, [self.filename], batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
107112
else:
108-
self.stream = create_fen_batch_stream(self.num_workers, self.filename, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
113+
self.stream = make_fen_batch_stream(self.num_workers, [self.filename], cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
109114

110115
def __iter__(self):
111116
return self
@@ -131,7 +136,7 @@ def __init__(
131136
destroy_stream,
132137
fetch_next,
133138
destroy_part,
134-
filename,
139+
filenames,
135140
cyclic,
136141
num_workers,
137142
batch_size=None,
@@ -147,7 +152,7 @@ def __init__(
147152
self.destroy_stream = destroy_stream
148153
self.fetch_next = fetch_next
149154
self.destroy_part = destroy_part
150-
self.filename = filename.encode('utf-8')
155+
self.filenames = filenames
151156
self.cyclic = cyclic
152157
self.num_workers = num_workers
153158
self.batch_size = batch_size
@@ -158,9 +163,9 @@ def __init__(
158163
self.device = device
159164

160165
if batch_size:
161-
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filename, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
166+
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filenames, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
162167
else:
163-
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filename, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
168+
self.stream = self.create_stream(self.feature_set, self.num_workers, self.filenames, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
164169

165170
def __iter__(self):
166171
return self
@@ -178,14 +183,19 @@ def __next__(self):
178183
def __del__(self):
179184
self.destroy_stream(self.stream)
180185

181-
# EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, const char* filename, int batch_size, bool cyclic,
186+
# EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, int num_files, const char* const* filenames, int batch_size, bool cyclic,
182187
# bool filtered, int random_fen_skipping, bool wld_filtered, int early_fen_skipping, int param_index)
183188
create_sparse_batch_stream = dll.create_sparse_batch_stream
184189
create_sparse_batch_stream.restype = ctypes.c_void_p
185-
create_sparse_batch_stream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool, ctypes.c_int, ctypes.c_int]
190+
create_sparse_batch_stream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.POINTER(ctypes.c_char_p), ctypes.c_int, ctypes.c_bool, ctypes.c_bool, ctypes.c_int, ctypes.c_bool, ctypes.c_int, ctypes.c_int]
186191
destroy_sparse_batch_stream = dll.destroy_sparse_batch_stream
187192
destroy_sparse_batch_stream.argtypes = [ctypes.c_void_p]
188193

194+
def make_sparse_batch_stream(feature_set, concurrency, filenames, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index):
195+
filenames_ = (ctypes.c_char_p * len(filenames))()
196+
filenames_[:] = [filename.encode('utf-8') for filename in filenames]
197+
return create_sparse_batch_stream(feature_set, concurrency, len(filenames), filenames_, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index)
198+
189199
fetch_next_sparse_batch = dll.fetch_next_sparse_batch
190200
fetch_next_sparse_batch.restype = SparseBatchPtr
191201
fetch_next_sparse_batch.argtypes = [ctypes.c_void_p]
@@ -211,14 +221,14 @@ def make_sparse_batch_from_fens(feature_set, fens, scores, plies, results):
211221
return b
212222

213223
class SparseBatchProvider(TrainingDataProvider):
214-
def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, early_fen_skipping=-1, param_index=0, device='cpu'):
224+
def __init__(self, feature_set, filenames, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, early_fen_skipping=-1, param_index=0, device='cpu'):
215225
super(SparseBatchProvider, self).__init__(
216226
feature_set,
217-
create_sparse_batch_stream,
227+
make_sparse_batch_stream,
218228
destroy_sparse_batch_stream,
219229
fetch_next_sparse_batch,
220230
destroy_sparse_batch,
221-
filename,
231+
filenames,
222232
cyclic,
223233
num_workers,
224234
batch_size,
@@ -230,10 +240,10 @@ def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1
230240
device)
231241

232242
class SparseBatchDataset(torch.utils.data.IterableDataset):
233-
def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, early_fen_skipping=-1, param_index=0, device='cpu'):
243+
def __init__(self, feature_set, filenames, batch_size, cyclic=True, num_workers=1, filtered=False, random_fen_skipping=0, wld_filtered=False, early_fen_skipping=-1, param_index=0, device='cpu'):
234244
super(SparseBatchDataset).__init__()
235245
self.feature_set = feature_set
236-
self.filename = filename
246+
self.filenames = filenames
237247
self.batch_size = batch_size
238248
self.cyclic = cyclic
239249
self.num_workers = num_workers
@@ -245,7 +255,7 @@ def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1
245255
self.device = device
246256

247257
def __iter__(self):
248-
return SparseBatchProvider(self.feature_set, self.filename, self.batch_size, cyclic=self.cyclic, num_workers=self.num_workers,
258+
return SparseBatchProvider(self.feature_set, self.filenames, self.batch_size, cyclic=self.cyclic, num_workers=self.num_workers,
249259
filtered=self.filtered, random_fen_skipping=self.random_fen_skipping, wld_filtered=self.wld_filtered, early_fen_skipping = self.early_fen_skipping, param_index=self.param_index, device=self.device)
250260

251261
class FixedNumBatchesDataset(Dataset):

0 commit comments

Comments (0)