#include <xmmintrin.h>  // _MM_GET/SET_EXCEPTION_MASK, _MM_MASK_INVALID (used by EnableNanInterrupt)

#include "omp_scoped_thread_limiter.h"
#include "random_device.h"
3333
// --- Training-method selection ---------------------------------------------
// Exactly one optimizer must be chosen at compile time:
//   SGDM     - stochastic gradient descent with momentum
//   ADADELTA - adaptive per-parameter learning rates
// The #error guards below turn a wrong combination into a compile failure
// instead of a silent misconfiguration.
#define SGDM
// #define ADADELTA

#if defined(SGDM) && defined(ADADELTA)
#error Only select one training method!
#elif !defined(SGDM) && !defined(ADADELTA)
#error Must select one training method!
#endif
42+
// Unmask the SSE "invalid operation" floating-point exception in MXCSR so
// that any operation producing a NaN traps immediately, rather than letting
// NaNs propagate silently through the network's weights.
// NOTE(review): the closing brace of this function is elided in the visible
// excerpt; reconstructed here as the single-statement body it plainly is.
inline void EnableNanInterrupt()
{
	const unsigned int currentMask = _MM_GET_EXCEPTION_MASK();
	_MM_SET_EXCEPTION_MASK(currentMask & ~_MM_MASK_INVALID);
}
@@ -398,11 +407,6 @@ void FCANN<ACTF, ACTFLast>::ApplyWeightUpdates(const Gradients &grad, float /*le
398407 assert (grad.biasGradients .size () == m_params.outputBias .size ());
399408 assert (grad.weightGradients .size () == grad.biasGradients .size ());
400409
401- /* // for SGD + M
402- m_params.weightsLastUpdate.resize(m_params.weights.size());
403- m_params.outputBiasLastUpdate.resize(m_params.outputBias.size());
404- */
405-
406410 if (m_params.weightsEg2 .size () != m_params.weights .size ())
407411 {
408412 InitializeOptimizationState_ ();
@@ -428,13 +432,19 @@ void FCANN<ACTF, ACTFLast>::ApplyWeightUpdates(const Gradients &grad, float /*le
428432 auto weightsGradientsBlock = grad.weightGradients [layer].block (0 , begin, inSize, numCols);
429433 auto biasGradientsBlock = grad.biasGradients [layer].block (0 , begin, 1 , numCols);
430434
435+ auto weightMaskBlock = m_params.weightMasks [layer].block (0 , begin, inSize, numCols);
436+
437+ #ifdef ADADELTA
431438 auto weightsEg2Block = m_params.weightsEg2 [layer].block (0 , begin, inSize, numCols);
432439 auto biasEg2Block = m_params.outputBiasEg2 [layer].block (0 , begin, 1 , numCols);
433-
434440 auto weightsRMSd2Block = m_params.weightsRMSd2 [layer].block (0 , begin, inSize, numCols);
435441 auto biasRMSd2Block = m_params.outputBiasRMSd2 [layer].block (0 , begin, 1 , numCols);
442+ #endif
436443
437- auto weightMaskBlock = m_params.weightMasks [layer].block (0 , begin, inSize, numCols);
444+ #ifdef SGDM
445+ auto weightsLastUpdateBlock = m_params.weightsLastUpdate [layer].block (0 , begin, inSize, numCols);
446+ auto outputBiasLastUpdateBlock = m_params.outputBiasLastUpdate [layer].block (0 , begin, 1 , numCols);
447+ #endif
438448
439449 #define L1_REG
440450 #ifdef L1_REG
@@ -476,20 +486,24 @@ void FCANN<ACTF, ACTFLast>::ApplyWeightUpdates(const Gradients &grad, float /*le
476486 NNMatrix weightReg = NNMatrix::Zero (weightsBlock.rows (), weightsBlock.cols ());
477487 #endif
478488
479- // update Eg2 ( ADADELTA)
480- float decay = 0 .95f ;
481- float e = 1e-6f ;
489+ # ifdef ADADELTA
490+ float decay = 0 .9f ;
491+ float e = 1e-7f ;
482492 weightsEg2Block.array () *= decay;
483493 weightsEg2Block.array () += (weightsGradientsBlock.array () * weightsGradientsBlock.array ()) * (1 .0f - decay);
484494 biasEg2Block.array () *= decay;
485495 biasEg2Block.array () += (biasGradientsBlock.array () * biasGradientsBlock.array ()) * (1 .0f - decay);
486496
487- // ADADELTA
488- NNMatrix weightDelta = -weightsGradientsBlock.array () * (weightsRMSd2Block.array () + e).sqrt () / (weightsEg2Block.array () + e).sqrt () + weightReg.array ();
497+ NNMatrix weightDelta = -weightsGradientsBlock.array () * (weightsRMSd2Block.array () + e).sqrt () / (weightsEg2Block.array () + e).sqrt () /* + weightReg.array()*/ ;
489498 NNVector biasDelta = -biasGradientsBlock.array () * (biasRMSd2Block.array () + e).sqrt () / (biasEg2Block.array () + e).sqrt ();
499+ #endif
490500
491- // NNMatrix weightDelta = -weightsGradientsBlock.array() * learningRate /*+ weightReg.array()*/;
492- // NNVector biasDelta = -biasGradientsBlock.array() * learningRate;
501+ #ifdef SGDM
502+ float lr = 0 .000001f ;
503+ float momentum = 0 .95f ;
504+ NNMatrix weightDelta = -weightsGradientsBlock.array () * lr + momentum * weightsLastUpdateBlock.array ()/* + weightReg.array()*/ ;
505+ NNVector biasDelta = -biasGradientsBlock.array () * lr + momentum * outputBiasLastUpdateBlock.array ();
506+ #endif
493507
494508 weightsBlock += weightDelta;
495509 weightsBlock.array () *= weightMaskBlock.array ();
@@ -501,11 +515,17 @@ void FCANN<ACTF, ACTFLast>::ApplyWeightUpdates(const Gradients &grad, float /*le
501515 throw LearningRateException ();
502516 }
503517
504- // ADADELTA
518+ # ifdef ADADELTA
505519 weightsRMSd2Block *= decay;
506520 weightsRMSd2Block.array () += weightDelta.array () * weightDelta.array () * (1 .0f - decay);
507521 biasRMSd2Block *= decay;
508522 biasRMSd2Block.array () += biasDelta.array () * biasDelta.array () * (1 .0f - decay);
523+ #endif
524+
525+ #ifdef SGDM
526+ weightsLastUpdateBlock = weightDelta;
527+ outputBiasLastUpdateBlock = biasDelta;
528+ #endif
509529 }
510530
511531 } // parallel
@@ -782,13 +802,19 @@ void FCANN<ACTF, ACTFLast>::InitializeOptimizationState_()
782802 m_params.weightsRMSd2 .resize (m_params.weights .size ());
783803 m_params.outputBiasRMSd2 .resize (m_params.outputBias .size ());
784804
805+ m_params.weightsLastUpdate .resize (m_params.weights .size ());
806+ m_params.outputBiasLastUpdate .resize (m_params.outputBias .size ());
807+
785808 for (size_t i = 0 ; i < m_params.weights .size (); ++i)
786809 {
787810 m_params.outputBiasEg2 [i] = NNVector::Zero (m_params.outputBias [i].cols ());
788811 m_params.weightsEg2 [i] = NNMatrix::Zero (m_params.weights [i].rows (), m_params.weights [i].cols ());
789812
790813 m_params.outputBiasRMSd2 [i] = NNVector::Zero (m_params.outputBias [i].cols ());
791814 m_params.weightsRMSd2 [i] = NNMatrix::Zero (m_params.weights [i].rows (), m_params.weights [i].cols ());
815+
816+ m_params.outputBiasLastUpdate [i] = NNVector::Zero (m_params.outputBias [i].cols ());
817+ m_params.weightsLastUpdate [i] = NNMatrix::Zero (m_params.weights [i].rows (), m_params.weights [i].cols ());
792818 }
793819}
794820
// (extraction residue: "0 commit comments" — GitHub page footer, not part of the source)