@@ -67,13 +67,18 @@ def get_fens(self):
6767 return strings
6868
6969FenBatchPtr = ctypes .POINTER (FenBatch )
70- # EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, const char* filename , int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered, int param_index)
70+ # EXPORT FenBatchStream* CDECL create_fen_batch_stream(int concurrency, int num_files, const char* const* filenames , int batch_size, bool cyclic, bool filtered, int random_fen_skipping, bool wld_filtered, int early_fen_skipping , int param_index)
7171create_fen_batch_stream = dll .create_fen_batch_stream
7272create_fen_batch_stream .restype = ctypes .c_void_p
73- create_fen_batch_stream .argtypes = [ctypes .c_int , ctypes .c_char_p , ctypes .c_int , ctypes .c_bool , ctypes .c_bool , ctypes .c_int , ctypes .c_bool , ctypes .c_int , ctypes .c_int ]
73+ create_fen_batch_stream .argtypes = [ctypes .c_int , ctypes .c_int , ctypes . POINTER ( ctypes . c_char_p ) , ctypes .c_int , ctypes .c_bool , ctypes .c_bool , ctypes .c_int , ctypes .c_bool , ctypes .c_int , ctypes .c_int ]
7474destroy_fen_batch_stream = dll .destroy_fen_batch_stream
7575destroy_fen_batch_stream .argtypes = [ctypes .c_void_p ]
7676
def make_fen_batch_stream(concurrency, filenames, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index):
    """Python-friendly wrapper around the native create_fen_batch_stream.

    Marshals the list of filename strings into a C array of UTF-8 encoded
    char pointers and prepends its length, matching the exported
    (int num_files, const char* const* filenames) signature.
    """
    encoded = [name.encode('utf-8') for name in filenames]
    c_filenames = (ctypes.c_char_p * len(encoded))(*encoded)
    return create_fen_batch_stream(
        concurrency,
        len(encoded),
        c_filenames,
        batch_size,
        cyclic,
        filtered,
        random_fen_skipping,
        wld_filtered,
        early_fen_skipping,
        param_index,
    )
81+
7782fetch_next_fen_batch = dll .fetch_next_fen_batch
7883fetch_next_fen_batch .restype = FenBatchPtr
7984fetch_next_fen_batch .argtypes = [ctypes .c_void_p ]
@@ -103,9 +108,9 @@ def __init__(
103108 self .param_index = param_index
104109
105110 if batch_size :
106- self .stream = create_fen_batch_stream (self .num_workers , self .filename , batch_size , cyclic , filtered , random_fen_skipping , wld_filtered , early_fen_skipping , param_index )
111+ self .stream = make_fen_batch_stream (self .num_workers , [ self .filename ] , batch_size , cyclic , filtered , random_fen_skipping , wld_filtered , early_fen_skipping , param_index )
107112 else :
108- self .stream = create_fen_batch_stream (self .num_workers , self .filename , cyclic , filtered , random_fen_skipping , wld_filtered , early_fen_skipping , param_index )
113+ self .stream = make_fen_batch_stream (self .num_workers , [ self .filename ] , cyclic , filtered , random_fen_skipping , wld_filtered , early_fen_skipping , param_index )
109114
110115 def __iter__ (self ):
111116 return self
@@ -131,7 +136,7 @@ def __init__(
131136 destroy_stream ,
132137 fetch_next ,
133138 destroy_part ,
134- filename ,
139+ filenames ,
135140 cyclic ,
136141 num_workers ,
137142 batch_size = None ,
@@ -147,7 +152,7 @@ def __init__(
147152 self .destroy_stream = destroy_stream
148153 self .fetch_next = fetch_next
149154 self .destroy_part = destroy_part
150- self .filename = filename . encode ( 'utf-8' )
155+ self .filenames = filenames
151156 self .cyclic = cyclic
152157 self .num_workers = num_workers
153158 self .batch_size = batch_size
@@ -158,9 +163,9 @@ def __init__(
158163 self .device = device
159164
160165 if batch_size :
161- self .stream = self .create_stream (self .feature_set , self .num_workers , self .filename , batch_size , cyclic , filtered , random_fen_skipping , wld_filtered , early_fen_skipping , param_index )
166+ self .stream = self .create_stream (self .feature_set , self .num_workers , self .filenames , batch_size , cyclic , filtered , random_fen_skipping , wld_filtered , early_fen_skipping , param_index )
162167 else :
163- self .stream = self .create_stream (self .feature_set , self .num_workers , self .filename , cyclic , filtered , random_fen_skipping , wld_filtered , early_fen_skipping , param_index )
168+ self .stream = self .create_stream (self .feature_set , self .num_workers , self .filenames , cyclic , filtered , random_fen_skipping , wld_filtered , early_fen_skipping , param_index )
164169
165170 def __iter__ (self ):
166171 return self
@@ -178,14 +183,19 @@ def __next__(self):
178183 def __del__ (self ):
179184 self .destroy_stream (self .stream )
180185
181- # EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, const char* filename , int batch_size, bool cyclic,
186+ # EXPORT Stream<SparseBatch>* CDECL create_sparse_batch_stream(const char* feature_set_c, int concurrency, int num_files, const char* const* filenames , int batch_size, bool cyclic,
182187# bool filtered, int random_fen_skipping, bool wld_filtered, int early_fen_skipping, int param_index)
183188create_sparse_batch_stream = dll .create_sparse_batch_stream
184189create_sparse_batch_stream .restype = ctypes .c_void_p
185- create_sparse_batch_stream .argtypes = [ctypes .c_char_p , ctypes .c_int , ctypes .c_char_p , ctypes .c_int , ctypes .c_bool , ctypes .c_bool , ctypes .c_int , ctypes .c_bool , ctypes .c_int , ctypes .c_int ]
190+ create_sparse_batch_stream .argtypes = [ctypes .c_char_p , ctypes .c_int , ctypes .c_int , ctypes . POINTER ( ctypes . c_char_p ) , ctypes .c_int , ctypes .c_bool , ctypes .c_bool , ctypes .c_int , ctypes .c_bool , ctypes .c_int , ctypes .c_int ]
186191destroy_sparse_batch_stream = dll .destroy_sparse_batch_stream
187192destroy_sparse_batch_stream .argtypes = [ctypes .c_void_p ]
188193
def make_sparse_batch_stream(feature_set, concurrency, filenames, batch_size, cyclic, filtered, random_fen_skipping, wld_filtered, early_fen_skipping, param_index):
    """Python-friendly wrapper around the native create_sparse_batch_stream.

    Marshals the list of filename strings into a C array of UTF-8 encoded
    char pointers and prepends its length, matching the exported
    (int num_files, const char* const* filenames) signature.
    """
    encoded = [name.encode('utf-8') for name in filenames]
    c_filenames = (ctypes.c_char_p * len(encoded))(*encoded)
    return create_sparse_batch_stream(
        feature_set,
        concurrency,
        len(encoded),
        c_filenames,
        batch_size,
        cyclic,
        filtered,
        random_fen_skipping,
        wld_filtered,
        early_fen_skipping,
        param_index,
    )
198+
189199fetch_next_sparse_batch = dll .fetch_next_sparse_batch
190200fetch_next_sparse_batch .restype = SparseBatchPtr
191201fetch_next_sparse_batch .argtypes = [ctypes .c_void_p ]
@@ -211,14 +221,14 @@ def make_sparse_batch_from_fens(feature_set, fens, scores, plies, results):
211221 return b
212222
213223class SparseBatchProvider (TrainingDataProvider ):
214- def __init__ (self , feature_set , filename , batch_size , cyclic = True , num_workers = 1 , filtered = False , random_fen_skipping = 0 , wld_filtered = False , early_fen_skipping = - 1 , param_index = 0 , device = 'cpu' ):
224+ def __init__ (self , feature_set , filenames , batch_size , cyclic = True , num_workers = 1 , filtered = False , random_fen_skipping = 0 , wld_filtered = False , early_fen_skipping = - 1 , param_index = 0 , device = 'cpu' ):
215225 super (SparseBatchProvider , self ).__init__ (
216226 feature_set ,
217- create_sparse_batch_stream ,
227+ make_sparse_batch_stream ,
218228 destroy_sparse_batch_stream ,
219229 fetch_next_sparse_batch ,
220230 destroy_sparse_batch ,
221- filename ,
231+ filenames ,
222232 cyclic ,
223233 num_workers ,
224234 batch_size ,
@@ -230,10 +240,10 @@ def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1
230240 device )
231241
232242class SparseBatchDataset (torch .utils .data .IterableDataset ):
233- def __init__ (self , feature_set , filename , batch_size , cyclic = True , num_workers = 1 , filtered = False , random_fen_skipping = 0 , wld_filtered = False , early_fen_skipping = - 1 , param_index = 0 , device = 'cpu' ):
243+ def __init__ (self , feature_set , filenames , batch_size , cyclic = True , num_workers = 1 , filtered = False , random_fen_skipping = 0 , wld_filtered = False , early_fen_skipping = - 1 , param_index = 0 , device = 'cpu' ):
234244 super (SparseBatchDataset ).__init__ ()
235245 self .feature_set = feature_set
236- self .filename = filename
246+ self .filenames = filenames
237247 self .batch_size = batch_size
238248 self .cyclic = cyclic
239249 self .num_workers = num_workers
@@ -245,7 +255,7 @@ def __init__(self, feature_set, filename, batch_size, cyclic=True, num_workers=1
245255 self .device = device
246256
247257 def __iter__ (self ):
248- return SparseBatchProvider (self .feature_set , self .filename , self .batch_size , cyclic = self .cyclic , num_workers = self .num_workers ,
258+ return SparseBatchProvider (self .feature_set , self .filenames , self .batch_size , cyclic = self .cyclic , num_workers = self .num_workers ,
249259 filtered = self .filtered , random_fen_skipping = self .random_fen_skipping , wld_filtered = self .wld_filtered , early_fen_skipping = self .early_fen_skipping , param_index = self .param_index , device = self .device )
250260
251261class FixedNumBatchesDataset (Dataset ):
0 commit comments