This repository was archived by the owner on May 21, 2022. It is now read-only.

Commit 97333ed: "Revamped learning code"
Parent: 81b5560

File tree: 8 files changed, +301 -202 lines
ann/ann.h (2 additions, 0 deletions)

@@ -172,6 +172,8 @@ class FCANN
 
     void UpdateWeightSemiSparse_();
 
+    void InitializeOptimizationState_();
+
     // this is used to ensure network stability
     constexpr static FP MAX_WEIGHT = 1000.0f;
 

ann/ann_evaluator.cpp (11 additions, 17 deletions)

@@ -60,37 +60,31 @@ void ANNEvaluator::Deserialize(std::istream &is)
     InvalidateCache();
 }
 
-void ANNEvaluator::Train(const std::vector<std::string> &positions, const NNMatrixRM &y, const std::vector<FeaturesConv::FeatureDescription> &featureDescriptions, float learningRate)
+float ANNEvaluator::Train(const NNMatrixRM &pred, EvalNet::Activations &act, const NNMatrixRM &targets)
 {
-    auto x = BoardsToFeatureRepresentation_(positions, featureDescriptions);
-
-    NNMatrixRM predictions;
-    EvalNet::Activations act;
-
-    m_mainAnn.InitializeActivations(act);
-
-    predictions = m_mainAnn.ForwardPropagate(x, act);
-
-    NNMatrixRM errorsDerivative = ComputeErrorDerivatives_(predictions, y, act.actIn[act.actIn.size() - 1], 1.0f, 1.0f);
+    NNMatrixRM errorsDerivative = ComputeErrorDerivatives_(pred, targets, act.actIn[act.actIn.size() - 1], 1.0f, 1.0f);
 
     EvalNet::Gradients grad;
 
     m_mainAnn.InitializeGradients(grad);
 
     m_mainAnn.BackwardPropagateComputeGrad(errorsDerivative, act, grad);
 
-    m_mainAnn.ApplyWeightUpdates(grad, learningRate, 0.0f);
+    m_mainAnn.ApplyWeightUpdates(grad, 1.0f, 0.0f);
 
     InvalidateCache();
+
+    return ((pred - targets).array() * (pred - targets).array()).sum() / targets.rows();
 }
 
-void ANNEvaluator::TrainLoop(const std::vector<std::string> &positions, const NNMatrixRM &y, int64_t epochs, const std::vector<FeaturesConv::FeatureDescription> &featureDescriptions)
+void ANNEvaluator::EvaluateForWhiteMatrix(const NNMatrixRM &x, NNMatrixRM &pred, EvalNet::Activations &act)
 {
-    auto x = BoardsToFeatureRepresentation_(positions, featureDescriptions);
-
-    LearnAnn::TrainANN(x, y, m_mainAnn, epochs);
+    if (act.act.size() == 0)
+    {
+        m_mainAnn.InitializeActivations(act);
+    }
 
-    InvalidateCache();
+    pred = m_mainAnn.ForwardPropagate(x, act);
 }
 
 void ANNEvaluator::TrainBounds(const std::vector<std::string> &positions, const std::vector<FeaturesConv::FeatureDescription> &featureDescriptions, float learningRate)
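With this change, feature conversion and the forward pass are no longer hidden inside Train(): the caller forward-propagates a batch once through EvaluateForWhiteMatrix() and then hands the predictions, activations, and targets back to Train(), which performs one backpropagation/weight-update step and returns the batch mean squared error. A minimal sketch of how a caller might drive the new interface follows; the loop, the `evaluator` object, and the `xBatch`/`targetsBatch`/`numIterations` names are hypothetical and not part of this commit:

// Hypothetical caller-side training loop (sketch only, not part of this commit).
// Assumes an ANNEvaluator `evaluator`, and Eigen row-major matrices `xBatch`
// and `targetsBatch` built by the caller, one row per training position.
NNMatrixRM pred;
EvalNet::Activations act; // reused across iterations; Train() reads act.actIn

for (int64_t iter = 0; iter < numIterations; ++iter)
{
    // bulk forward pass over the batch; initializes `act` on the first call
    evaluator.EvaluateForWhiteMatrix(xBatch, pred, act);

    // one backprop + weight-update step; returns the batch mean squared error
    float mse = evaluator.Train(pred, act, targetsBatch);

    (void)mse; // e.g. log, or use for convergence checks
}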

ann/ann_evaluator.h (3 additions, 2 deletions)

@@ -68,9 +68,10 @@ class ANNEvaluator : public EvaluatorIface
 
     void Deserialize(std::istream &is);
 
-    void Train(const std::vector<std::string> &positions, const NNMatrixRM &y, const std::vector<FeaturesConv::FeatureDescription> &featureDescriptions, float learningRate);
+    float Train(const NNMatrixRM &pred, EvalNet::Activations &act, const NNMatrixRM &targets);
 
-    void TrainLoop(const std::vector<std::string> &positions, const NNMatrixRM &y, int64_t epochs, const std::vector<FeaturesConv::FeatureDescription> &featureDescriptions);
+    // this is a special bulk evaluate for training
+    void EvaluateForWhiteMatrix(const NNMatrixRM &x, NNMatrixRM &pred, EvalNet::Activations &act);
 
     void TrainBounds(const std::vector<std::string> &positions, const std::vector<FeaturesConv::FeatureDescription> &featureDescriptions, float learningRate);

ann/ann_impl.h (33 additions, 19 deletions)

@@ -120,22 +120,14 @@ FCANN<ACTF, ACTFLast>::FCANN(
             // we have a fully connected layer
             m_params.weightMasks.push_back(NNMatrix::Ones(in_size, out_size));
         }
-
-        m_params.outputBiasLastUpdate.push_back(NNVector::Zero(out_size));
-        m_params.weightsLastUpdate.push_back(NNMatrix::Zero(in_size, out_size));
-
-        m_params.outputBiasEg2.push_back(NNVector::Zero(out_size));
-        m_params.weightsEg2.push_back(NNMatrix::Zero(in_size, out_size));
-
-        m_params.outputBiasRMSd2.push_back(NNVector::Zero(out_size));
-        m_params.weightsRMSd2.push_back(NNMatrix::Zero(in_size, out_size));
     }
 
     m_params.evalTmp.resize(hiddenLayers.size() + 2);
     m_params.evalSingleTmp.resize(hiddenLayers.size() + 2);
 
     UpdateWeightMasksRegions_();
     UpdateWeightSemiSparse_();
+    InitializeOptimizationState_();
 }
 
 template <ActivationFunc ACTF, ActivationFunc ACTFLast>

@@ -400,20 +392,21 @@ float FCANN<ACTF, ACTFLast>::TrainGDM(const MatrixBase<Derived1> &x, const Matri
 }
 
 template <ActivationFunc ACTF, ActivationFunc ACTFLast>
-void FCANN<ACTF, ACTFLast>::ApplyWeightUpdates(const Gradients &grad, float learningRate, float reg)
+void FCANN<ACTF, ACTFLast>::ApplyWeightUpdates(const Gradients &grad, float /*learningRate*/, float reg)
 {
     assert(grad.weightGradients.size() == m_params.weights.size());
     assert(grad.biasGradients.size() == m_params.outputBias.size());
    assert(grad.weightGradients.size() == grad.biasGradients.size());
 
+    /* // for SGD + M
     m_params.weightsLastUpdate.resize(m_params.weights.size());
     m_params.outputBiasLastUpdate.resize(m_params.outputBias.size());
+    */
 
-    m_params.weightsEg2.resize(m_params.weights.size());
-    m_params.outputBiasEg2.resize(m_params.outputBias.size());
-
-    m_params.weightsRMSd2.resize(m_params.weights.size());
-    m_params.outputBiasRMSd2.resize(m_params.outputBias.size());
+    if (m_params.weightsEg2.size() != m_params.weights.size())
+    {
+        InitializeOptimizationState_();
+    }
 
     for (size_t layer = 0; layer < m_params.weights.size(); ++layer)
     {

@@ -484,8 +477,8 @@ void FCANN<ACTF, ACTFLast>::ApplyWeightUpdates(const Gradients &grad, float lear
 #endif
 
             // update Eg2 (ADADELTA)
-            float decay = 0.99f;
-            float e = 1e-8f;
+            float decay = 0.95f;
+            float e = 1e-6f;
             weightsEg2Block.array() *= decay;
             weightsEg2Block.array() += (weightsGradientsBlock.array() * weightsGradientsBlock.array()) * (1.0f - decay);
             biasEg2Block.array() *= decay;

@@ -498,9 +491,9 @@ void FCANN<ACTF, ACTFLast>::ApplyWeightUpdates(const Gradients &grad, float lear
             //NNMatrix weightDelta = -weightsGradientsBlock.array() * learningRate /*+ weightReg.array()*/;
             //NNVector biasDelta = -biasGradientsBlock.array() * learningRate;
 
-            weightsBlock += weightDelta * learningRate;
+            weightsBlock += weightDelta;
             weightsBlock.array() *= weightMaskBlock.array();
-            biasBlock += biasDelta * learningRate;
+            biasBlock += biasDelta;
 
             FP weightMax = std::max(std::max(weightsBlock.maxCoeff(), -weightsBlock.minCoeff()), std::max(biasBlock.maxCoeff(), -biasBlock.minCoeff()));
             if (weightMax > MAX_WEIGHT)

@@ -779,6 +772,27 @@ void FCANN<ACTF, ACTFLast>::UpdateWeightSemiSparse_()
     m_params.weightsSemiSparseCurrent = true;
 }
 
+
+template <ActivationFunc ACTF, ActivationFunc ACTFLast>
+void FCANN<ACTF, ACTFLast>::InitializeOptimizationState_()
+{
+    m_params.weightsEg2.resize(m_params.weights.size());
+    m_params.outputBiasEg2.resize(m_params.outputBias.size());
+
+    m_params.weightsRMSd2.resize(m_params.weights.size());
+    m_params.outputBiasRMSd2.resize(m_params.outputBias.size());
+
+    for (size_t i = 0; i < m_params.weights.size(); ++i)
+    {
+        m_params.outputBiasEg2[i] = NNVector::Zero(m_params.outputBias[i].cols());
+        m_params.weightsEg2[i] = NNMatrix::Zero(m_params.weights[i].rows(), m_params.weights[i].cols());
+
+        m_params.outputBiasRMSd2[i] = NNVector::Zero(m_params.outputBias[i].cols());
+        m_params.weightsRMSd2[i] = NNMatrix::Zero(m_params.weights[i].rows(), m_params.weights[i].cols());
+    }
+}
+
+
 /* serialization format:
  * numLayers
  * for each layer:
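ApplyWeightUpdates() now ignores its learningRate argument entirely: the step size comes only from the ADADELTA accumulators, with the decay lowered from 0.99 to 0.95 and epsilon raised from 1e-8 to 1e-6. For reference, the per-parameter rule that the Eg2/RMSd2 buffers implement is the standard ADADELTA update. The sketch below is illustrative only and is not code from this repository, which applies the same arithmetic blockwise with Eigen arrays:

#include <cmath>

// Standard ADADELTA update for a single weight (illustrative scalar sketch).
// eg2 and rmsd2 are this weight's running averages E[g^2] and E[dx^2],
// both initialized to zero (as InitializeOptimizationState_() does blockwise).
inline void AdadeltaStep(float &w, float g, float &eg2, float &rmsd2)
{
    const float decay = 0.95f; // rho, as set in this commit
    const float e = 1e-6f;     // epsilon, as set in this commit

    eg2 = decay * eg2 + (1.0f - decay) * g * g;                     // accumulate E[g^2]
    float delta = -(std::sqrt(rmsd2 + e) / std::sqrt(eg2 + e)) * g; // scale-free step
    rmsd2 = decay * rmsd2 + (1.0f - decay) * delta * delta;         // accumulate E[dx^2]
    w += delta;                                                     // no external learning rate
}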

ann/features_conv.cpp (20 additions, 0 deletions)

@@ -291,6 +291,26 @@ void PushSquareFeatures(std::vector<T> &ret, const Board &/*board*/, AttackMaps
     {
         PushPosFloat(ret, sq, atkMaps.whiteCtrl[sq], group);
         PushPosFloat(ret, sq, atkMaps.blackCtrl[sq], group + 1);
+
+        /*
+        PieceType pt = board.GetPieceAtSquare(sq);
+
+        if (pt == EMPTY)
+        {
+            PushPosFloat(ret, sq, 0.0f, group + 2);
+            PushPosFloat(ret, sq, 0.0f, group + 3);
+        }
+        else if (GetColor(pt) == WHITE)
+        {
+            PushPosFloat(ret, sq, NormalizeCount(SEE::SEE_MAT[board.GetPieceAtSquare(sq)], SEE::SEE_MAT[WK]), group + 2);
+            PushPosFloat(ret, sq, 0.0f, group + 3);
+        }
+        else if (GetColor(pt) == BLACK)
+        {
+            PushPosFloat(ret, sq, 0.0f, group + 2);
+            PushPosFloat(ret, sq, NormalizeCount(SEE::SEE_MAT[board.GetPieceAtSquare(sq)], SEE::SEE_MAT[WK]), group + 3);
+        }
+        */
     }
 
     group += 2;

ann/features_conv.h (10 additions, 0 deletions)

@@ -79,6 +79,16 @@ struct FeatureDescription
 template <typename T>
 void ConvertBoardToNN(Board &board, std::vector<T> &ret);
 
+inline int64_t GetNumFeatures()
+{
+    Board b;
+
+    std::vector<FeaturesConv::FeatureDescription> ret;
+    FeaturesConv::ConvertBoardToNN(b, ret);
+
+    return static_cast<int64_t>(ret.size());
+}
+
 // additional info for conversion
 struct ConvertMovesInfo
 {
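GetNumFeatures() derives the input dimensionality by converting a default-constructed Board once and counting the resulting feature descriptions. A plausible use, shown here as hypothetical caller code rather than anything in this commit, is sizing the feature matrix (and hence the network's input layer) before training:

// Hypothetical caller (sketch only, not part of this commit):
// size a batch matrix from the feature count; `batchSize` is assumed.
int64_t numFeatures = FeaturesConv::GetNumFeatures();
NNMatrixRM xBatch(batchSize, numFeatures); // one row per position, one column per feature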
