Skip to content

Commit c06b078

Browse files
committed
Implement range scaler for regression model to improve its performance
1 parent 7a1517e commit c06b078

14 files changed

Lines changed: 157 additions & 104 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Add the following dependency to your POM file:
99
<dependency>
1010
<groupId>com.github.chen0040</groupId>
1111
<artifactId>java-ann-mlp</artifactId>
12-
<version>1.0.1</version>
12+
<version>1.0.2</version>
1313
</dependency>
1414
```
1515

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
<groupId>com.github.chen0040</groupId>
88
<artifactId>java-ann-mlp</artifactId>
9-
<version>1.0.2</version>
9+
<version>1.0.3</version>
1010

1111
<licenses>
1212
<license>

src/main/java/com/github/chen0040/mlp/ann/MLP.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package com.github.chen0040.mlp.ann;
22
import com.github.chen0040.data.frame.DataFrame;
33
import com.github.chen0040.data.frame.DataRow;
4+
import com.github.chen0040.data.utils.Scaler;
45
import com.github.chen0040.data.utils.transforms.Standardization;
6+
import com.github.chen0040.mlp.functions.RangeScaler;
57

68
import java.util.ArrayList;
79
import java.util.List;
@@ -12,7 +14,7 @@
1214
*/
1315
public abstract class MLP extends MLPNet {
1416
private Standardization inputNormalization;
15-
private Standardization outputNormalization;
17+
private RangeScaler outputNormalization;
1618

1719
private boolean normalizeOutputs;
1820

@@ -22,7 +24,7 @@ public void copy(MLPNet rhs) throws CloneNotSupportedException {
2224

2325
MLP rhs2 = (MLP)rhs;
2426
inputNormalization = rhs2.inputNormalization == null ? null : (Standardization)rhs2.inputNormalization.clone();
25-
outputNormalization = rhs2.outputNormalization == null ? null : (Standardization)rhs2.outputNormalization.clone();
27+
outputNormalization = rhs2.outputNormalization == null ? null : (RangeScaler) rhs2.outputNormalization.clone();
2628
normalizeOutputs = rhs2.normalizeOutputs;
2729
}
2830

@@ -54,7 +56,7 @@ public void train(DataFrame batch, int training_epoches)
5456
targets.add(target);
5557
}
5658
}
57-
outputNormalization = new Standardization(targets);
59+
outputNormalization = new RangeScaler(targets);
5860
}
5961

6062

src/main/java/com/github/chen0040/mlp/ann/MLPLayer.java

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,13 @@
66
import com.github.chen0040.mlp.functions.TransferFunction;
77

88
import java.util.ArrayList;
9-
import java.util.Random;
9+
import java.util.List;
1010

1111

1212
//default network assumes input and output are in the range of [0, 1]
1313
public class MLPLayer implements Cloneable {
14-
private static Random rand = new Random();
1514
private TransferFunction transfer = new LogSig();
16-
private ArrayList<MLPNeuron> neurons;
15+
private List<MLPNeuron> neurons;
1716

1817
public void copy(MLPLayer rhs){
1918
transfer = rhs.transfer == null ? null : (TransferFunction) ((AbstractTransferFunction)rhs.transfer).clone();
@@ -69,7 +68,7 @@ public void setTransfer(TransferFunction transfer) {
6968
this.transfer = transfer;
7069
}
7170

72-
public ArrayList<MLPNeuron> getNeurons() {
71+
public List<MLPNeuron> getNeurons() {
7372
return neurons;
7473
}
7574

@@ -90,7 +89,7 @@ public double[] forward_propagate(double[] input)
9089
return output;
9190
}
9291

93-
protected void adjust_weights(double[] input, double learningRate, double momentum)
92+
protected void adjust_weights(double[] input, double learningRate)
9493
{
9594
for(int j=0; j< neurons.size(); j++)
9695
{
@@ -99,13 +98,12 @@ protected void adjust_weights(double[] input, double learningRate, double moment
9998
for(int i=0; i < dimension; ++i) {
10099

101100
double sink_error = neuron.error;
102-
double dWeight = neuron.getWeightDelta(i);
101+
103102
double weight = neuron.getWeight(i);
104103

105104
double dw = learningRate * sink_error * input[i];
106-
weight += (dw + momentum * dWeight);
107-
dWeight = dw;
108-
neuron.setWeightDelta(i, dWeight);
105+
weight += dw;
106+
neuron.setWeightDelta(i, dw);
109107
neuron.setWeight(i, weight);
110108
}
111109
}
@@ -130,7 +128,10 @@ public double[] back_propagate(double[] error)
130128
{
131129
MLPNeuron neuron= neurons.get(i);
132130
double y = neuron.output;
133-
neuron.error = y * (1-y) * error[i];
131+
double[] values = neuron.values;
132+
double hx = neuron.getValue(values);
133+
134+
neuron.error = transfer.gradient(hx, y) * error[i];
134135
}
135136

136137
int k = dimension();
@@ -149,4 +150,6 @@ public double[] back_propagate(double[] error)
149150

150151
return propagated_error;
151152
}
153+
154+
152155
}

src/main/java/com/github/chen0040/mlp/ann/MLPNet.java

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
//default network assumes input and output are in the range of [0, 1]
1414
public class MLPNet implements Cloneable {
1515
protected MLPLayer inputLayer =null;
16-
protected MLPLayer outputLayer =null;
16+
public MLPLayer outputLayer =null;
1717

1818
protected List<MLPLayer> hiddenLayers;
1919

@@ -23,9 +23,6 @@ public class MLPNet implements Cloneable {
2323
@Setter
2424
protected double learningRate =0.25; //learning rate
2525

26-
@Getter
27-
@Setter
28-
protected double momentum =0.9; //momentum term for \Delta w[i][j]
2926

3027
public void copy(MLPNet rhs) throws CloneNotSupportedException {
3128
inputLayer = rhs.inputLayer == null ? null : (MLPLayer)rhs.inputLayer.clone();
@@ -37,7 +34,6 @@ public void copy(MLPNet rhs) throws CloneNotSupportedException {
3734
}
3835

3936
learningRate = rhs.learningRate;
40-
momentum = rhs.momentum;
4137
}
4238

4339
public MLPLayer createInputLayer(int dimension){
@@ -55,7 +51,7 @@ public MLPLayer createOutputLayer(int dimension){
5551

5652
public MLPNet()
5753
{
58-
hiddenLayers = new ArrayList<MLPLayer>();
54+
hiddenLayers = new ArrayList<>();
5955
}
6056

6157

@@ -81,8 +77,10 @@ public double train(double[] input, double[] target)
8177
propagated_output = hiddenLayers.get(i).forward_propagate(propagated_output);
8278
}
8379
propagated_output = outputLayer.forward_propagate(propagated_output);
84-
80+
81+
8582
double error = get_target_error(target);
83+
8684

8785
//backward propagate
8886
double[] propagated_error = outputLayer.back_propagate(minus(target, propagated_output));
@@ -93,10 +91,10 @@ public double train(double[] input, double[] target)
9391
//adjust weights
9492
double[] input2 = inputLayer.output();
9593
for(int i = 0; i < hiddenLayers.size(); ++i){
96-
hiddenLayers.get(i).adjust_weights(input2, getLearningRate(), getMomentum());
94+
hiddenLayers.get(i).adjust_weights(input2, getLearningRate());
9795
input2 = hiddenLayers.get(i).output();
9896
}
99-
outputLayer.adjust_weights(input2, getLearningRate(), getMomentum());
97+
outputLayer.adjust_weights(input2, getLearningRate());
10098

10199

102100
return error;
@@ -109,6 +107,7 @@ public double[] minus(double[] a, double[] b){
109107
}
110108
return c;
111109
}
110+
112111

113112
protected double get_target_error(double[] target)
114113
{
@@ -124,17 +123,6 @@ protected double get_target_error(double[] target)
124123
return t_error;
125124
}
126125

127-
public double test(double[] input, double[] target)
128-
{
129-
double[] propagated_output = inputLayer.setOutput(input);
130-
for(int i=0; i < hiddenLayers.size(); ++i) {
131-
propagated_output = hiddenLayers.get(i).forward_propagate(propagated_output);
132-
}
133-
propagated_output = outputLayer.forward_propagate(propagated_output);
134-
135-
return get_target_error(target);
136-
}
137-
138126
public double[] transform(double[] input)
139127
{
140128
double[] propagated_output = inputLayer.setOutput(input);

src/main/java/com/github/chen0040/mlp/ann/MLPNeuron.java

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@
66

77

88
public class MLPNeuron implements Cloneable {
9-
public double bias_weight = 0;
10-
public double bias = 0;
9+
double bias_weight = 0;
10+
private double bias = 0;
1111

12-
public double output = 0;
13-
public double error = 0;
12+
double output = 0;
13+
double error = 0;
14+
15+
double[] values = null;
1416

1517
private static Random rand = new Random();
1618

@@ -48,22 +50,12 @@ public double getWeight(int index){
4850
if(weights.containsKey(index)){
4951
return weights.get(index);
5052
}else{
51-
double weight = rand.nextDouble() - 0.5;
53+
double weight = (rand.nextDouble() - 0.5) / 10;
5254
weights.put(index, weight);
5355
return weight;
5456
}
5557
}
5658

57-
public double getWeightDelta(int index){
58-
if(weightDeltas.containsKey(index)){
59-
return weightDeltas.get(index);
60-
}else{
61-
double dweight = rand.nextDouble() - 0.5;
62-
weightDeltas.put(index, dweight);
63-
return dweight;
64-
}
65-
}
66-
6759
public void setWeightDelta(int index, double val){
6860
weightDeltas.put(index, val);
6961
}
@@ -77,12 +69,13 @@ public MLPNeuron()
7769
bias_weight =rand.nextDouble()-0.5;
7870
bias =-1;
7971

80-
weights = new HashMap<Integer, Double>();
81-
weightDeltas = new HashMap<Integer, Double>();
72+
weights = new HashMap<>();
73+
weightDeltas = new HashMap<>();
8274
}
8375

84-
public double getValue(double[] x)
76+
double getValue(double[] x)
8577
{
78+
values = x;
8679
double sum=0;
8780

8881
for(int i=0; i < x.length; i++)

src/main/java/com/github/chen0040/mlp/ann/regression/MLPRegression.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import com.github.chen0040.data.frame.DataFrame;
44
import com.github.chen0040.data.frame.DataRow;
5+
import com.github.chen0040.mlp.functions.LogSig;
6+
import com.github.chen0040.mlp.functions.TransferFunction;
57
import lombok.Getter;
68
import lombok.Setter;
79

@@ -70,14 +72,18 @@ public void fit(DataFrame batch) {
7072
mlp = new MLPWithNumericOutput();
7173
mlp.setNormalizeOutputs(true);
7274

75+
TransferFunction transferFunction = new LogSig();
76+
77+
7378
int dimension = batch.row(0).toArray().length;
7479

7580
mlp.setLearningRate(learningRate);
7681
mlp.createInputLayer(dimension);
7782
for (int hiddenLayerNeuronCount : hiddenLayers){
78-
mlp.addHiddenLayer(hiddenLayerNeuronCount);
83+
mlp.addHiddenLayer(hiddenLayerNeuronCount, transferFunction);
7984
}
8085
mlp.createOutputLayer(1);
86+
mlp.outputLayer.setTransfer(transferFunction);
8187

8288
mlp.train(batch, epoches);
8389
}

src/main/java/com/github/chen0040/mlp/functions/BiPolarLogSig.java

Lines changed: 0 additions & 14 deletions
This file was deleted.

src/main/java/com/github/chen0040/mlp/functions/IdentityFunction.java

Lines changed: 0 additions & 16 deletions
This file was deleted.

src/main/java/com/github/chen0040/mlp/functions/LogSig.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@ public double calculate(double x)
1010
return 1/(Math.exp(-x)+1);
1111
}
1212

13+
14+
@Override public double gradient(double hx, double y) {
15+
y = calculate(hx);
16+
return y * (1-y);
17+
}
18+
19+
1320
@Override
1421
public Object clone(){
1522
return new LogSig();

0 commit comments

Comments
 (0)