1313//default network assumes input and output are in the range of [0, 1]
1414public class MLPNet implements Cloneable {
1515 protected MLPLayer inputLayer =null ;
16- protected MLPLayer outputLayer =null ;
16+ public MLPLayer outputLayer =null ;
1717
1818 protected List <MLPLayer > hiddenLayers ;
1919
@@ -23,9 +23,6 @@ public class MLPNet implements Cloneable {
2323 @ Setter
2424 protected double learningRate =0.25 ; //learning rate
2525
26- @ Getter
27- @ Setter
28- protected double momentum =0.9 ; //momentum term for \Delta w[i][j]
2926
3027 public void copy (MLPNet rhs ) throws CloneNotSupportedException {
3128 inputLayer = rhs .inputLayer == null ? null : (MLPLayer )rhs .inputLayer .clone ();
@@ -37,7 +34,6 @@ public void copy(MLPNet rhs) throws CloneNotSupportedException {
3734 }
3835
3936 learningRate = rhs .learningRate ;
40- momentum = rhs .momentum ;
4137 }
4238
4339 public MLPLayer createInputLayer (int dimension ){
@@ -55,7 +51,7 @@ public MLPLayer createOutputLayer(int dimension){
5551
5652 public MLPNet ()
5753 {
58- hiddenLayers = new ArrayList <MLPLayer >();
54+ hiddenLayers = new ArrayList <>();
5955 }
6056
6157
@@ -81,8 +77,10 @@ public double train(double[] input, double[] target)
8177 propagated_output = hiddenLayers .get (i ).forward_propagate (propagated_output );
8278 }
8379 propagated_output = outputLayer .forward_propagate (propagated_output );
84-
80+
81+
8582 double error = get_target_error (target );
83+
8684
8785 //backward propagate
8886 double [] propagated_error = outputLayer .back_propagate (minus (target , propagated_output ));
@@ -93,10 +91,10 @@ public double train(double[] input, double[] target)
9391 //adjust weights
9492 double [] input2 = inputLayer .output ();
9593 for (int i = 0 ; i < hiddenLayers .size (); ++i ){
96- hiddenLayers .get (i ).adjust_weights (input2 , getLearningRate (), getMomentum () );
94+ hiddenLayers .get (i ).adjust_weights (input2 , getLearningRate ());
9795 input2 = hiddenLayers .get (i ).output ();
9896 }
99- outputLayer .adjust_weights (input2 , getLearningRate (), getMomentum () );
97+ outputLayer .adjust_weights (input2 , getLearningRate ());
10098
10199
102100 return error ;
@@ -109,6 +107,7 @@ public double[] minus(double[] a, double[] b){
109107 }
110108 return c ;
111109 }
110+
112111
113112 protected double get_target_error (double [] target )
114113 {
@@ -124,17 +123,6 @@ protected double get_target_error(double[] target)
124123 return t_error ;
125124 }
126125
127- public double test (double [] input , double [] target )
128- {
129- double [] propagated_output = inputLayer .setOutput (input );
130- for (int i =0 ; i < hiddenLayers .size (); ++i ) {
131- propagated_output = hiddenLayers .get (i ).forward_propagate (propagated_output );
132- }
133- propagated_output = outputLayer .forward_propagate (propagated_output );
134-
135- return get_target_error (target );
136- }
137-
138126 public double [] transform (double [] input )
139127 {
140128 double [] propagated_output = inputLayer .setOutput (input );
0 commit comments