Skip to content
This repository was archived by the owner on May 21, 2022. It is now read-only.

Commit 13b7fab

Browse files
committed
Added batch training option for ANNEvaluator
1 parent f2d1ec8 commit 13b7fab

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

ann/ann_evaluator.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "ann_evaluator.h"
1919

2020
#include <fstream>
21+
#include <set>
2122

2223
#include "consts.h"
2324

@@ -77,6 +78,16 @@ float ANNEvaluator::Train(const NNMatrixRM &pred, EvalNet::Activations &act, con
7778
return ((pred - targets).array() * (pred - targets).array()).sum() / targets.rows();
7879
}
7980

81+
float ANNEvaluator::Train(const NNMatrixRM &positions, const NNMatrixRM &targets)
82+
{
83+
// in this version (where we don't have predictions already) we can simply call ANN's TrainGDM
84+
float e = m_mainAnn.TrainGDM(positions, targets, 1.0f, 1.0f);
85+
86+
InvalidateCache();
87+
88+
return e;
89+
}
90+
8091
void ANNEvaluator::EvaluateForWhiteMatrix(const NNMatrixRM &x, NNMatrixRM &pred, EvalNet::Activations &act)
8192
{
8293
if (act.act.size() == 0)
@@ -136,6 +147,31 @@ Score ANNEvaluator::EvaluateForWhiteImpl(Board &b, Score lowerBound, Score upper
136147
return *hashResult;
137148
}
138149

150+
/*
151+
static uint64_t hits = 0;
152+
static uint64_t misses = 0;
153+
static std::set<size_t> seen;
154+
155+
Board::SlowFeatures sf;
156+
b.GetSlowFeatures(sf);
157+
size_t hash = sf.Hash();
158+
159+
if (seen.find(hash) != seen.end())
160+
{
161+
++hits;
162+
}
163+
else
164+
{
165+
++misses;
166+
seen.insert(hash);
167+
}
168+
169+
if (((misses + hits) % 1000000) == 0)
170+
{
171+
std::cout << hits << "/" << misses << " (" << (static_cast<float>(hits) / (hits + misses)) << ")" << std::endl;
172+
}
173+
*/
174+
139175
FeaturesConv::ConvertBoardToNN(b, m_convTmp);
140176

141177
// we have to map every time because the vector's buffer could have moved

ann/ann_evaluator.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ class ANNEvaluator : public EvaluatorIface
7070

7171
float Train(const NNMatrixRM &pred, EvalNet::Activations &act, const NNMatrixRM &targets);
7272

73+
float Train(const NNMatrixRM &positions, const NNMatrixRM &targets);
74+
7375
// this is a special bulk evaluate for training
7476
void EvaluateForWhiteMatrix(const NNMatrixRM &x, NNMatrixRM &pred, EvalNet::Activations &act);
7577

0 commit comments

Comments
 (0)