@@ -258,6 +258,93 @@ struct HalfKAv2Factorized {
258258 }
259259};
260260
261+ // ksq must not be oriented
262+ static Square orient_flip_2 (Color color, Square sq, Square ksq)
263+ {
264+ bool h = ksq.file () < fileE;
265+ if (color == Color::Black)
266+ sq = sq.flippedVertically ();
267+ if (h)
268+ sq = sq.flippedHorizontally ();
269+ return sq;
270+ }
271+
272+ struct HalfKAv2_hm {
273+ static constexpr int NUM_SQ = 64 ;
274+ static constexpr int NUM_PT = 11 ;
275+ static constexpr int NUM_PLANES = NUM_SQ * NUM_PT;
276+ static constexpr int INPUTS = NUM_PLANES * NUM_SQ / 2 ;
277+
278+ static constexpr int MAX_ACTIVE_FEATURES = 32 ;
279+
280+ static constexpr int KingBuckets[64 ] = {
281+ -1 , -1 , -1 , -1 , 31 , 30 , 29 , 28 ,
282+ -1 , -1 , -1 , -1 , 27 , 26 , 25 , 24 ,
283+ -1 , -1 , -1 , -1 , 23 , 22 , 21 , 20 ,
284+ -1 , -1 , -1 , -1 , 19 , 18 , 17 , 16 ,
285+ -1 , -1 , -1 , -1 , 15 , 14 , 13 , 12 ,
286+ -1 , -1 , -1 , -1 , 11 , 10 , 9 , 8 ,
287+ -1 , -1 , -1 , -1 , 7 , 6 , 5 , 4 ,
288+ -1 , -1 , -1 , -1 , 3 , 2 , 1 , 0
289+ };
290+
291+ static int feature_index (Color color, Square ksq, Square sq, Piece p)
292+ {
293+ Square o_ksq = orient_flip_2 (color, ksq, ksq);
294+ auto p_idx = static_cast <int >(p.type ()) * 2 + (p.color () != color);
295+ if (p_idx == 11 )
296+ --p_idx; // pack the opposite king into the same NUM_SQ * NUM_SQ
297+ return static_cast <int >(orient_flip_2 (color, sq, ksq)) + p_idx * NUM_SQ + KingBuckets[static_cast <int >(o_ksq)] * NUM_PLANES;
298+ }
299+
300+ static std::pair<int , int > fill_features_sparse (const TrainingDataEntry& e, int * features, float * values, Color color)
301+ {
302+ auto & pos = e.pos ;
303+ auto pieces = pos.piecesBB ();
304+ auto ksq = pos.kingSquare (color);
305+
306+ int j = 0 ;
307+ for (Square sq : pieces)
308+ {
309+ auto p = pos.pieceAt (sq);
310+ values[j] = 1 .0f ;
311+ features[j] = feature_index (color, ksq, sq, p);
312+ ++j;
313+ }
314+
315+ return { j, INPUTS };
316+ }
317+ };
318+
319+ struct HalfKAv2_hmFactorized {
320+ // Factorized features
321+ static constexpr int PIECE_INPUTS = HalfKAv2_hm::NUM_SQ * HalfKAv2_hm::NUM_PT;
322+ static constexpr int INPUTS = HalfKAv2_hm::INPUTS + PIECE_INPUTS;
323+
324+ static constexpr int MAX_PIECE_FEATURES = 32 ;
325+ static constexpr int MAX_ACTIVE_FEATURES = HalfKAv2_hm::MAX_ACTIVE_FEATURES + MAX_PIECE_FEATURES;
326+
327+ static std::pair<int , int > fill_features_sparse (const TrainingDataEntry& e, int * features, float * values, Color color)
328+ {
329+ const auto [start_j, offset] = HalfKAv2_hm::fill_features_sparse (e, features, values, color);
330+ auto & pos = e.pos ;
331+ auto pieces = pos.piecesBB ();
332+ auto ksq = pos.kingSquare (color);
333+
334+ int j = start_j;
335+ for (Square sq : pieces)
336+ {
337+ auto p = pos.pieceAt (sq);
338+ auto p_idx = static_cast <int >(p.type ()) * 2 + (p.color () != color);
339+ values[j] = 1 .0f ;
340+ features[j] = offset + (p_idx * HalfKAv2_hm::NUM_SQ) + static_cast <int >(orient_flip_2 (color, sq, ksq));
341+ ++j;
342+ }
343+
344+ return { j, INPUTS };
345+ }
346+ };
347+
261348template <typename T, typename ... Ts>
262349struct FeatureSet
263350{
@@ -797,6 +884,14 @@ extern "C" {
797884 {
798885 return new SparseBatch (FeatureSet<HalfKAv2Factorized>{}, entries);
799886 }
887+ else if (feature_set == " HalfKAv2_hm" )
888+ {
889+ return new SparseBatch (FeatureSet<HalfKAv2_hm>{}, entries);
890+ }
891+ else if (feature_set == " HalfKAv2_hm^" )
892+ {
893+ return new SparseBatch (FeatureSet<HalfKAv2_hmFactorized>{}, entries);
894+ }
800895 fprintf (stderr, " Unknown feature_set %s\n " , feature_set_c);
801896 return nullptr ;
802897 }
@@ -842,6 +937,14 @@ extern "C" {
842937 {
843938 return new FeaturedBatchStream<FeatureSet<HalfKAv2Factorized>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
844939 }
940+ else if (feature_set == " HalfKAv2_hm" )
941+ {
942+ return new FeaturedBatchStream<FeatureSet<HalfKAv2_hm>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
943+ }
944+ else if (feature_set == " HalfKAv2_hm^" )
945+ {
946+ return new FeaturedBatchStream<FeatureSet<HalfKAv2_hmFactorized>, SparseBatch>(concurrency, filename, batch_size, cyclic, skipPredicate);
947+ }
845948 fprintf (stderr, " Unknown feature_set %s\n " , feature_set_c);
846949 return nullptr ;
847950 }
0 commit comments