Skip to content

Commit 46d85de

Browse files
committed
Tools transform option for filtering data for training nn-335a9b2d8a80.nnue
Append hash of first master net trained with filter method Hardcode depth 6 and remove option to set depth Underscores for consistency Filter out standard startpos positions too
1 parent 399d556 commit 46d85de

File tree

1 file changed

+225
-1
lines changed

1 file changed

+225
-1
lines changed

src/tools/transform.cpp

Lines changed: 225 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ namespace Stockfish::Tools
6363
}
6464
};
6565

66+
struct FilterParams
67+
{
68+
std::string input_filename = "in.binpack";
69+
std::string output_filename = "out.binpack";
70+
bool debug_print = false;
71+
};
72+
6673
[[nodiscard]] std::int16_t nudge(NudgedStaticParams& params, std::int16_t static_eval_i16, std::int16_t deep_eval_i16)
6774
{
6875
auto saturate_i32_to_i16 = [](int v) {
@@ -502,11 +509,228 @@ namespace Stockfish::Tools
502509
do_rescore(params);
503510
}
504511

512+
void do_filter_data_335a9b2d8a80(FilterParams& params)
513+
{
514+
// TODO: Use SfenReader once it works correctly in sequential mode. See issue #271
515+
auto in = Tools::open_sfen_input_file(params.input_filename);
516+
auto readsome = [&in, mutex = std::mutex{}](int n) mutable -> PSVector {
517+
518+
PSVector psv;
519+
psv.reserve(n);
520+
521+
std::unique_lock lock(mutex);
522+
523+
for (int i = 0; i < n; ++i)
524+
{
525+
auto ps_opt = in->next();
526+
if (ps_opt.has_value())
527+
{
528+
psv.emplace_back(*ps_opt);
529+
}
530+
else
531+
{
532+
break;
533+
}
534+
}
535+
536+
return psv;
537+
};
538+
539+
auto sfen_format = SfenOutputType::Binpack;
540+
541+
auto out = SfenWriter(
542+
params.output_filename,
543+
Threads.size(),
544+
std::numeric_limits<std::uint64_t>::max(),
545+
sfen_format);
546+
547+
// About Search::Limits
548+
// Be careful because this member variable is global and affects other threads.
549+
auto& limits = Search::Limits;
550+
551+
// Make the search equivalent to the "go infinite" command. (Because it is troublesome if time management is done)
552+
limits.infinite = true;
553+
554+
// Since PV is an obstacle when displayed, erase it.
555+
limits.silent = true;
556+
557+
// If you use this, it will be compared with the accumulated nodes of each thread. Therefore, do not use it.
558+
limits.nodes = 0;
559+
560+
// depth is also processed by the one passed as an argument of Tools::search().
561+
limits.depth = 0;
562+
563+
std::atomic<std::uint64_t> num_processed = 0;
564+
std::atomic<std::uint64_t> num_standard_startpos = 0;
565+
std::atomic<std::uint64_t> num_position_in_check = 0;
566+
std::atomic<std::uint64_t> num_move_already_is_capture = 0;
567+
std::atomic<std::uint64_t> num_capture_or_promo_skipped_multipv_cap0 = 0;
568+
std::atomic<std::uint64_t> num_capture_or_promo_skipped_multipv_cap1 = 0;
569+
570+
Threads.execute_with_workers([&](auto& th){
571+
Position& pos = th.rootPos;
572+
StateInfo si;
573+
const bool frc = Options["UCI_Chess960"];
574+
575+
const bool debug_print = params.debug_print;
576+
for (;;)
577+
{
578+
PSVector psv = readsome(5000);
579+
if (psv.empty())
580+
break;
581+
582+
for(auto& ps : psv)
583+
{
584+
pos.set_from_packed_sfen(ps.sfen, &si, &th, frc);
585+
bool should_skip_position = false;
586+
if (pos.checkers()) {
587+
// Skip if in check
588+
if (debug_print) {
589+
sync_cout << "[debug] " << pos.fen() << sync_endl
590+
<< "[debug] Position is in check" << sync_endl
591+
<< "[debug]" << sync_endl;
592+
}
593+
num_position_in_check.fetch_add(1);
594+
should_skip_position = true;
595+
} else if (pos.capture_or_promotion((Stockfish::Move)ps.move)) {
596+
// Skip if the provided move is already a capture or promotion
597+
if (debug_print) {
598+
sync_cout << "[debug] " << pos.fen() << sync_endl
599+
<< "[debug] Provided move is capture or promo: "
600+
<< UCI::move((Stockfish::Move)ps.move, false)
601+
<< sync_endl
602+
<< "[debug]" << sync_endl;
603+
}
604+
num_move_already_is_capture.fetch_add(1);
605+
should_skip_position = true;
606+
} else if (pos.fen() == "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1") {
607+
num_standard_startpos.fetch_add(1);
608+
should_skip_position = true;
609+
} else {
610+
auto [search_val, pvs] = Search::search(pos, 6, 2);
611+
if (!pvs.empty() && th.rootMoves.size() > 0) {
612+
auto best_move = th.rootMoves[0].pv[0];
613+
bool more_than_one_valid_move = th.rootMoves.size() > 1;
614+
if (debug_print) {
615+
sync_cout << "[debug] " << pos.fen() << sync_endl;
616+
sync_cout << "[debug] Main PV move: "
617+
<< UCI::move(best_move, false) << " "
618+
<< th.rootMoves[0].score << " " << sync_endl;
619+
if (more_than_one_valid_move) {
620+
sync_cout << "[debug] 2nd PV move: "
621+
<< UCI::move(th.rootMoves[1].pv[0], false) << " "
622+
<< th.rootMoves[1].score << " " << sync_endl;
623+
} else {
624+
sync_cout << "[debug] The only valid move" << sync_endl;
625+
}
626+
}
627+
if (pos.capture_or_promotion(best_move)) {
628+
// skip if multipv 1st line bestmove is a capture or promo
629+
if (debug_print) {
630+
sync_cout << "[debug] Move is capture or promo: " << UCI::move(best_move, false)
631+
<< sync_endl
632+
<< "[debug] 1st best move at depth 6 multipv 2" << sync_endl
633+
<< "[debug]" << sync_endl;
634+
}
635+
num_capture_or_promo_skipped_multipv_cap0.fetch_add(1);
636+
should_skip_position = true;
637+
} else if (more_than_one_valid_move && pos.capture_or_promotion(th.rootMoves[1].pv[0])) {
638+
// skip if multipv 2nd line bestmove is a capture or promo
639+
if (debug_print) {
640+
sync_cout << "[debug] Move is capture or promo: " << UCI::move(best_move, false)
641+
<< sync_endl
642+
<< "[debug] 2nd best move at depth 6 multipv 2" << sync_endl
643+
<< "[debug]" << sync_endl;
644+
}
645+
num_capture_or_promo_skipped_multipv_cap1.fetch_add(1);
646+
should_skip_position = true;
647+
}
648+
}
649+
}
650+
pos.sfen_pack(ps.sfen, false);
651+
// nnue-pytorch training data loader skips positions with score VALUE_NONE
652+
if (should_skip_position)
653+
ps.score = 32002; // VALUE_NONE
654+
ps.padding = 0;
655+
656+
out.write(th.id(), ps);
657+
658+
auto p = num_processed.fetch_add(1) + 1;
659+
if (p % 10000 == 0) {
660+
auto c = num_position_in_check.load();
661+
auto a = num_move_already_is_capture.load();
662+
auto s = num_standard_startpos.load();
663+
auto multipv_cap0 = num_capture_or_promo_skipped_multipv_cap0.load();
664+
auto multipv_cap1 = num_capture_or_promo_skipped_multipv_cap1.load();
665+
sync_cout << "Processed " << p << " positions. Skipped " << (c + a + s + multipv_cap0 + multipv_cap1) << " positions."
666+
<< sync_endl
667+
<< " Static filter: " << (a + c + s)
668+
<< " (capture or promo: " << a << ", in check: " << c << ", startpos: " << s << ")"
669+
<< sync_endl
670+
<< " MultiPV filter: " << (multipv_cap0 + multipv_cap1)
671+
<< " (cap0: " << multipv_cap0 << ", cap1: " << multipv_cap1 << ")"
672+
<< " depth 6 multipv 2" << sync_endl;
673+
}
674+
}
675+
}
676+
});
677+
Threads.wait_for_workers_finished();
678+
679+
std::cout << "Finished.\n";
680+
}
681+
682+
void do_filter_335a9b2d8a80(FilterParams& params)
683+
{
684+
if (ends_with(params.input_filename, ".binpack"))
685+
{
686+
do_filter_data_335a9b2d8a80(params);
687+
}
688+
else
689+
{
690+
std::cerr << "Invalid input file type.\n";
691+
}
692+
}
693+
694+
void filter_335a9b2d8a80(std::istringstream& is)
695+
{
696+
FilterParams params{};
697+
698+
while(true)
699+
{
700+
std::string token;
701+
is >> token;
702+
703+
if (token == "")
704+
break;
705+
706+
else if (token == "input_file")
707+
is >> params.input_filename;
708+
else if (token == "output_file")
709+
is >> params.output_filename;
710+
else if (token == "debug_print")
711+
is >> params.debug_print;
712+
else
713+
{
714+
std::cout << "ERROR: Unknown option " << token << ". Exiting...\n";
715+
return;
716+
}
717+
}
718+
719+
std::cout << "Performing transform filter_335a9b2d8a80 with parameters:\n";
720+
std::cout << "input_file : " << params.input_filename << '\n';
721+
std::cout << "output_file : " << params.output_filename << '\n';
722+
std::cout << "debug_print : " << params.debug_print << '\n';
723+
std::cout << '\n';
724+
725+
do_filter_335a9b2d8a80(params);
726+
}
727+
505728
void transform(std::istringstream& is)
506729
{
507730
const std::map<std::string, CommandFunc> subcommands = {
508731
{ "nudged_static", &nudged_static },
509-
{ "rescore", &rescore }
732+
{ "rescore", &rescore },
733+
{ "filter_335a9b2d8a80", &filter_335a9b2d8a80 }
510734
};
511735

512736
Eval::NNUE::init();

0 commit comments

Comments
 (0)