Skip to content

Commit 25ea87a

Browse files
committed
Use standardization as the default output scaling instead of the range scaler
1 parent fd0a506 commit 25ea87a

5 files changed

Lines changed: 50 additions & 15 deletions

File tree

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,7 @@
501501
<dependency>
502502
<groupId>com.github.chen0040</groupId>
503503
<artifactId>java-data-frame</artifactId>
504-
<version>1.0.9</version>
504+
<version>1.0.11</version>
505505
</dependency>
506506

507507

src/main/java/com/github/chen0040/mlp/ann/MLP.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import com.github.chen0040.mlp.enums.LearningMethod;
66
import com.github.chen0040.mlp.enums.WeightUpdateMode;
77
import com.github.chen0040.mlp.functions.RangeScaler;
8+
import lombok.Getter;
9+
import lombok.Setter;
810

911
import java.util.ArrayList;
1012
import java.util.List;
@@ -13,14 +15,14 @@
1315
/**
1416
* Created by xschen on 21/8/15.
1517
*/
18+
@Getter
19+
@Setter
1620
public abstract class MLP extends MLPNet {
1721
private Standardization inputNormalization;
18-
private RangeScaler outputNormalization;
22+
private Standardization outputNormalization = new Standardization();
1923

2024
private boolean adaptiveLearningRateEnabled = false;
2125

22-
23-
2426
private boolean normalizeOutputs;
2527

2628
public MLP(){
@@ -55,7 +57,7 @@ public void train(DataFrame batch, int training_epoches)
5557
targets.add(target);
5658
}
5759
}
58-
outputNormalization = new RangeScaler(targets);
60+
outputNormalization.fit(targets);
5961
}
6062

6163
double[][][] dE_dwji_prev = null;

src/main/java/com/github/chen0040/mlp/ann/regression/MLPRegression.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import com.github.chen0040.data.frame.DataFrame;
44
import com.github.chen0040.data.frame.DataRow;
5+
import com.github.chen0040.data.utils.transforms.Standardization;
56
import com.github.chen0040.mlp.enums.LearningMethod;
67
import com.github.chen0040.mlp.enums.WeightUpdateMode;
78
import com.github.chen0040.mlp.functions.Identity;
9+
import com.github.chen0040.mlp.functions.RangeScaler;
810
import com.github.chen0040.mlp.functions.Sigmoid;
911
import com.github.chen0040.mlp.functions.TransferFunction;
1012
import lombok.Getter;
@@ -61,6 +63,10 @@ public void enabledAdaptiveLearningRate(boolean enabled){
6163
@Setter
6264
private double learningRate = 0.2;
6365

66+
@Getter
67+
@Setter
68+
private Standardization outputNormalization = new Standardization();
69+
6470
public MLPRegression(){
6571
epoches = 1000;
6672

@@ -89,6 +95,7 @@ public void fit(DataFrame batch) {
8995

9096
mlp = new MLPWithNumericOutput();
9197
mlp.setNormalizeOutputs(true);
98+
mlp.setOutputNormalization(outputNormalization);
9299
mlp.setMiniBatchSize(miniBatchSize);
93100
mlp.setLearningMethod(learningMethod);
94101
mlp.setWeightUpdateMode(weightUpdateMode);

src/main/java/com/github/chen0040/mlp/functions/RangeScaler.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.github.chen0040.mlp.functions;
22

33

4+
import com.github.chen0040.data.utils.transforms.Standardization;
45
import lombok.Getter;
56
import lombok.Setter;
67

@@ -14,13 +15,14 @@
1415
*/
1516
@Getter
1617
@Setter
17-
public class RangeScaler implements Cloneable {
18+
public class RangeScaler extends Standardization {
1819

1920
private final Map<Integer, Double> minValue = new HashMap<>();
2021
private final Map<Integer, Double> maxValue = new HashMap<>();
2122

2223

23-
public RangeScaler(List<double[]> targets) {
24+
@Override
25+
public void fit(List<double[]> targets) {
2426
for(int i = 0; i < targets.size(); ++i){
2527
double[] values = targets.get(i);
2628
for(int j=0; j < values.length; ++j) {
@@ -42,6 +44,7 @@ public Object clone() throws CloneNotSupportedException {
4244
}
4345

4446

47+
@Override
4548
public double[] standardize(double[] target) {
4649
double[] result = new double[target.length];
4750
for(int i=0; i < result.length; ++i){
@@ -51,6 +54,7 @@ public double[] standardize(double[] target) {
5154
}
5255

5356

57+
@Override
5458
public double[] revert(double[] target) {
5559
double[] result = new double[target.length];
5660
for(int i=0; i < result.length; ++i){

src/test/java/com/github/chen0040/mlp/ann/regression/MLPRegressionUnitTest.java

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
import com.github.chen0040.data.frame.DataQuery;
66
import com.github.chen0040.data.frame.Sampler;
77
import com.github.chen0040.data.utils.TupleTwo;
8+
import com.github.chen0040.data.utils.transforms.Standardization;
89
import com.github.chen0040.mlp.enums.LearningMethod;
910
import com.github.chen0040.mlp.enums.WeightUpdateMode;
1011
import com.github.chen0040.mlp.functions.Identity;
12+
import com.github.chen0040.mlp.functions.RangeScaler;
1113
import com.github.chen0040.mlp.functions.ReLU;
1214
import com.github.chen0040.mlp.functions.Sigmoid;
1315
import com.github.chen0040.mlp.utils.FileUtils;
@@ -40,18 +42,33 @@ public static double randn(){
4042
return r * Math.sin(theta);
4143
}
4244

45+
46+
4347
@Test
44-
public void testSimple() {
45-
InputStream inputStream = FileUtils.getResource("heart_scale");
48+
public void test_simple_regression() {
49+
DataQuery.DataFrameQueryBuilder schema = DataQuery.blank()
50+
.newInput("x1")
51+
.newInput("x2")
52+
.newOutput("y")
53+
.end();
54+
55+
// y = 4 + 0.5 * x1 + 0.2 * x2
56+
Sampler.DataSampleBuilder sampler = new Sampler()
57+
.forColumn("x1").generate((name, index) -> randn() * 0.3 + index / 100.0)
58+
.forColumn("x2").generate((name, index) -> randn() * 0.3 + index * index / 10000.0)
59+
.forColumn("y").generate((name, index) -> 4 + 0.5 * index / 100.0 + 0.2 * index * index / 10000.0 + randn() * 0.3)
60+
.end();
4661

47-
DataFrame dataFrame = DataQuery.libsvm().from(inputStream).build();
62+
DataFrame data = schema.build();
4863

49-
System.out.println(dataFrame.head(10));
64+
data = sampler.sample(data, 200);
5065

51-
TupleTwo<DataFrame, DataFrame> miniFrames = dataFrame.shuffle().split(0.9);
66+
TupleTwo<DataFrame, DataFrame> frames = data.shuffle().split(0.9);
5267

53-
DataFrame trainingData = miniFrames._1();
54-
DataFrame crossValidationData = miniFrames._2();
68+
DataFrame trainingData = frames._1();
69+
System.out.println(trainingData.head(10));
70+
71+
DataFrame crossValidationData = frames._2();
5572

5673
MLPRegression regression = new MLPRegression();
5774
regression.setHiddenLayers(8);
@@ -68,7 +85,7 @@ public void testSimple() {
6885
}
6986

7087
@Test
71-
public void test_simple_regression() {
88+
public void test_simple_regression_scaled_output() {
7289
DataQuery.DataFrameQueryBuilder schema = DataQuery.blank()
7390
.newInput("x1")
7491
.newInput("x2")
@@ -95,6 +112,7 @@ public void test_simple_regression() {
95112

96113
MLPRegression regression = new MLPRegression();
97114
regression.setHiddenLayers(8);
115+
regression.setOutputNormalization(new RangeScaler());
98116
regression.setEpoches(1000);
99117
regression.fit(trainingData);
100118

@@ -135,6 +153,7 @@ public void test_simple_regression_weight_constraint() {
135153

136154
MLPRegression regression = new MLPRegression();
137155
regression.setHiddenLayers(8);
156+
regression.setOutputNormalization(new RangeScaler());
138157
regression.setEpoches(1000);
139158
regression.setWeightConstraint(80);
140159
regression.fit(trainingData);
@@ -218,6 +237,7 @@ public void test_simple_regression_mini_batch_gradient_descend() {
218237
MLPRegression regression = new MLPRegression();
219238
regression.setWeightUpdateMode(WeightUpdateMode.MiniBatchGradientDescend);
220239
regression.setMiniBatchSize(20);
240+
regression.setOutputNormalization(new RangeScaler());
221241
regression.setHiddenLayers(8);
222242
regression.setEpoches(1000);
223243
regression.fit(trainingData);
@@ -261,6 +281,7 @@ public void test_simple_regression_mini_batch_gradient_descend_L2_regularization
261281
regression.setWeightUpdateMode(WeightUpdateMode.MiniBatchGradientDescend);
262282
regression.setMiniBatchSize(20);
263283
regression.setHiddenLayers(8);
284+
regression.setOutputNormalization(new RangeScaler());
264285
regression.setL2Penalty(0.001);
265286
regression.setEpoches(1000);
266287
regression.fit(trainingData);
@@ -305,6 +326,7 @@ public void test_simple_regression_mini_batch_gradient_descend_adaptive_learning
305326
regression.enabledAdaptiveLearningRate(true);
306327
regression.setMiniBatchSize(20);
307328
regression.setHiddenLayers(8);
329+
regression.setOutputNormalization(new RangeScaler());
308330
regression.setEpoches(1000);
309331
regression.fit(trainingData);
310332

0 commit comments

Comments
 (0)