This repository was archived by the owner on May 21, 2022. It is now read-only.

Commit 9cb95cd (parent: 13b7fab)

Enabled easier switching between ANN update rules

1 file changed: ann/ann_impl.h (41 additions, 15 deletions)
@@ -31,6 +31,15 @@
 #include "omp_scoped_thread_limiter.h"
 #include "random_device.h"
 
+#define SGDM
+//#define ADADELTA
+
+#if defined(SGDM) && defined(ADADELTA)
+#error Only select one training method!
+#elif !defined(SGDM) && !defined(ADADELTA)
+#error Must select one training method!
+#endif
+
 inline void EnableNanInterrupt()
 {
 	_MM_SET_EXCEPTION_MASK(_MM_GET_EXCEPTION_MASK() & ~_MM_MASK_INVALID);
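
The new block turns the choice of update rule into a compile-time switch: exactly one of SGDM or ADADELTA must be defined, and the #error directives reject any other combination at build time. To train with ADADELTA instead of SGD with momentum, flip the two defines:

//#define SGDM
#define ADADELTA
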
@@ -398,11 +407,6 @@ void FCANN<ACTF, ACTFLast>::ApplyWeightUpdates(const Gradients &grad, float /*le
 	assert(grad.biasGradients.size() == m_params.outputBias.size());
 	assert(grad.weightGradients.size() == grad.biasGradients.size());
 
-	/* // for SGD + M
-	m_params.weightsLastUpdate.resize(m_params.weights.size());
-	m_params.outputBiasLastUpdate.resize(m_params.outputBias.size());
-	*/
-
 	if (m_params.weightsEg2.size() != m_params.weights.size())
 	{
 		InitializeOptimizationState_();
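
The commented-out "SGD + M" resize is deleted here rather than revived: the momentum buffers are now sized and zeroed in InitializeOptimizationState_ alongside the existing ADADELTA state (see the final hunk below).
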
@@ -428,13 +432,19 @@ void FCANN<ACTF, ACTFLast>::ApplyWeightUpdates(const Gradients &grad, float /*le
 			auto weightsGradientsBlock = grad.weightGradients[layer].block(0, begin, inSize, numCols);
 			auto biasGradientsBlock = grad.biasGradients[layer].block(0, begin, 1, numCols);
 
+			auto weightMaskBlock = m_params.weightMasks[layer].block(0, begin, inSize, numCols);
+
+#ifdef ADADELTA
 			auto weightsEg2Block = m_params.weightsEg2[layer].block(0, begin, inSize, numCols);
 			auto biasEg2Block = m_params.outputBiasEg2[layer].block(0, begin, 1, numCols);
-
 			auto weightsRMSd2Block = m_params.weightsRMSd2[layer].block(0, begin, inSize, numCols);
 			auto biasRMSd2Block = m_params.outputBiasRMSd2[layer].block(0, begin, 1, numCols);
+#endif
 
-			auto weightMaskBlock = m_params.weightMasks[layer].block(0, begin, inSize, numCols);
+#ifdef SGDM
+			auto weightsLastUpdateBlock = m_params.weightsLastUpdate[layer].block(0, begin, inSize, numCols);
+			auto outputBiasLastUpdateBlock = m_params.outputBiasLastUpdate[layer].block(0, begin, 1, numCols);
+#endif
 
 #define L1_REG
 #ifdef L1_REG
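
With the new guards, each optimizer's per-column state views are created only when that rule is compiled in, and the shared weightMaskBlock moves above the guards since both rules apply the mask. For readers unfamiliar with Eigen, a minimal standalone sketch of the block() views used here (assumes Eigen 3; the matrix and sizes are hypothetical):

#include <Eigen/Dense>

int main()
{
	Eigen::MatrixXf m = Eigen::MatrixXf::Zero(4, 8);
	auto blk = m.block(0, 2, 4, 3); // 4x3 view starting at row 0, col 2; no copy
	blk.array() += 1.0f;            // writes through to m, as the *Block views above do
}
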
@@ -476,20 +486,24 @@ void FCANN<ACTF, ACTFLast>::ApplyWeightUpdates(const Gradients &grad, float /*le
 			NNMatrix weightReg = NNMatrix::Zero(weightsBlock.rows(), weightsBlock.cols());
 #endif
 
-			// update Eg2 (ADADELTA)
-			float decay = 0.95f;
-			float e = 1e-6f;
+#ifdef ADADELTA
+			float decay = 0.9f;
+			float e = 1e-7f;
 			weightsEg2Block.array() *= decay;
 			weightsEg2Block.array() += (weightsGradientsBlock.array() * weightsGradientsBlock.array()) * (1.0f - decay);
 			biasEg2Block.array() *= decay;
 			biasEg2Block.array() += (biasGradientsBlock.array() * biasGradientsBlock.array()) * (1.0f - decay);
 
-			// ADADELTA
-			NNMatrix weightDelta = -weightsGradientsBlock.array() * (weightsRMSd2Block.array() + e).sqrt() / (weightsEg2Block.array() + e).sqrt() + weightReg.array();
+			NNMatrix weightDelta = -weightsGradientsBlock.array() * (weightsRMSd2Block.array() + e).sqrt() / (weightsEg2Block.array() + e).sqrt() /*+ weightReg.array()*/;
 			NNVector biasDelta = -biasGradientsBlock.array() * (biasRMSd2Block.array() + e).sqrt() / (biasEg2Block.array() + e).sqrt();
+#endif
 
-			//NNMatrix weightDelta = -weightsGradientsBlock.array() * learningRate /*+ weightReg.array()*/;
-			//NNVector biasDelta = -biasGradientsBlock.array() * learningRate;
+#ifdef SGDM
+			float lr = 0.000001f;
+			float momentum = 0.95f;
+			NNMatrix weightDelta = -weightsGradientsBlock.array() * lr + momentum * weightsLastUpdateBlock.array()/*+ weightReg.array()*/;
+			NNVector biasDelta = -biasGradientsBlock.array() * lr + momentum * outputBiasLastUpdateBlock.array();
+#endif
 
 			weightsBlock += weightDelta;
 			weightsBlock.array() *= weightMaskBlock.array();
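
In scalar form, the two deltas computed above look like this (a minimal sketch with hypothetical names; the constants match the diff). Note that the old commented-out SGD lines used the learningRate argument, while the new SGDM path hard-codes its own lr; ADADELTA needs no global rate at all.

#include <cmath>

// ADADELTA: scale the gradient by the RMS of recent deltas over the RMS of
// recent gradients. Eg2 accumulates squared gradients with exponential decay.
float AdadeltaDelta(float g, float &Eg2, float Ed2,
                    float decay = 0.9f, float e = 1e-7f)
{
	Eg2 = decay * Eg2 + (1.0f - decay) * g * g;
	return -g * std::sqrt(Ed2 + e) / std::sqrt(Eg2 + e);
}

// SGD with momentum: a small fixed-rate gradient step plus 95% of the
// previous applied update.
float SgdmDelta(float g, float lastUpdate,
                float lr = 0.000001f, float momentum = 0.95f)
{
	return -g * lr + momentum * lastUpdate;
}
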
@@ -501,11 +515,17 @@ void FCANN<ACTF, ACTFLast>::ApplyWeightUpdates(const Gradients &grad, float /*le
 				throw LearningRateException();
 			}
 
-			// ADADELTA
+#ifdef ADADELTA
 			weightsRMSd2Block *= decay;
 			weightsRMSd2Block.array() += weightDelta.array() * weightDelta.array() * (1.0f - decay);
 			biasRMSd2Block *= decay;
 			biasRMSd2Block.array() += biasDelta.array() * biasDelta.array() * (1.0f - decay);
+#endif
+
+#ifdef SGDM
+			weightsLastUpdateBlock = weightDelta;
+			outputBiasLastUpdateBlock = biasDelta;
+#endif
 		}
 
 	} // parallel
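
After the deltas are applied, each rule updates its own state: ADADELTA folds the squared applied delta into the RMSd2 running average that feeds the numerator of its next step, while SGDM stores the whole applied delta so the next step's momentum term can carry it forward.
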
@@ -782,13 +802,19 @@ void FCANN<ACTF, ACTFLast>::InitializeOptimizationState_()
 	m_params.weightsRMSd2.resize(m_params.weights.size());
 	m_params.outputBiasRMSd2.resize(m_params.outputBias.size());
 
+	m_params.weightsLastUpdate.resize(m_params.weights.size());
+	m_params.outputBiasLastUpdate.resize(m_params.outputBias.size());
+
 	for (size_t i = 0; i < m_params.weights.size(); ++i)
 	{
 		m_params.outputBiasEg2[i] = NNVector::Zero(m_params.outputBias[i].cols());
 		m_params.weightsEg2[i] = NNMatrix::Zero(m_params.weights[i].rows(), m_params.weights[i].cols());
 
 		m_params.outputBiasRMSd2[i] = NNVector::Zero(m_params.outputBias[i].cols());
 		m_params.weightsRMSd2[i] = NNMatrix::Zero(m_params.weights[i].rows(), m_params.weights[i].cols());
+
+		m_params.outputBiasLastUpdate[i] = NNVector::Zero(m_params.outputBias[i].cols());
+		m_params.weightsLastUpdate[i] = NNMatrix::Zero(m_params.weights[i].rows(), m_params.weights[i].cols());
 	}
 }
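
All optimizer state is sized and zeroed here unconditionally, which keeps ApplyWeightUpdates free of allocations regardless of which rule is compiled in. Because the lastUpdate buffers start at zero, the first SGDM step is a plain gradient step, with momentum building up over the following iterations; with Eg2 and RMSd2 at zero, ADADELTA's early per-weight steps are capped at roughly sqrt(e / (1 - decay)) until the running averages warm up.
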
