55import com .github .chen0040 .data .frame .DataQuery ;
66import com .github .chen0040 .data .frame .Sampler ;
77import com .github .chen0040 .data .utils .TupleTwo ;
8+ import com .github .chen0040 .data .utils .transforms .Standardization ;
89import com .github .chen0040 .mlp .enums .LearningMethod ;
910import com .github .chen0040 .mlp .enums .WeightUpdateMode ;
1011import com .github .chen0040 .mlp .functions .Identity ;
12+ import com .github .chen0040 .mlp .functions .RangeScaler ;
1113import com .github .chen0040 .mlp .functions .ReLU ;
1214import com .github .chen0040 .mlp .functions .Sigmoid ;
1315import com .github .chen0040 .mlp .utils .FileUtils ;
@@ -40,18 +42,33 @@ public static double randn(){
4042 return r * Math .sin (theta );
4143 }
4244
45+
46+
4347 @ Test
44- public void testSimple () {
45- InputStream inputStream = FileUtils .getResource ("heart_scale" );
48+ public void test_simple_regression () {
49+ DataQuery .DataFrameQueryBuilder schema = DataQuery .blank ()
50+ .newInput ("x1" )
51+ .newInput ("x2" )
52+ .newOutput ("y" )
53+ .end ();
54+
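55+ // y = 4 + 0.5 * (index / 100.0) + 0.2 * (index * index / 10000.0) + noise, i.e. built from the noise-free means of x1 and x2 rather than from the sampled x1/x2 values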
56+ Sampler .DataSampleBuilder sampler = new Sampler ()
57+ .forColumn ("x1" ).generate ((name , index ) -> randn () * 0.3 + index / 100.0 )
58+ .forColumn ("x2" ).generate ((name , index ) -> randn () * 0.3 + index * index / 10000.0 )
59+ .forColumn ("y" ).generate ((name , index ) -> 4 + 0.5 * index / 100.0 + 0.2 * index * index / 10000.0 + randn () * 0.3 )
60+ .end ();
4661
47- DataFrame dataFrame = DataQuery . libsvm (). from ( inputStream ) .build ();
62+ DataFrame data = schema .build ();
4863
49- System . out . println ( dataFrame . head ( 10 ) );
64+ data = sampler . sample ( data , 200 );
5065
51- TupleTwo <DataFrame , DataFrame > miniFrames = dataFrame .shuffle ().split (0.9 );
66+ TupleTwo <DataFrame , DataFrame > frames = data .shuffle ().split (0.9 );
5267
53- DataFrame trainingData = miniFrames ._1 ();
54- DataFrame crossValidationData = miniFrames ._2 ();
68+ DataFrame trainingData = frames ._1 ();
69+ System .out .println (trainingData .head (10 ));
70+
71+ DataFrame crossValidationData = frames ._2 ();
5572
5673 MLPRegression regression = new MLPRegression ();
5774 regression .setHiddenLayers (8 );
@@ -68,7 +85,7 @@ public void testSimple() {
6885 }
6986
7087 @ Test
71- public void test_simple_regression () {
88+ public void test_simple_regression_scaled_output () {
7289 DataQuery .DataFrameQueryBuilder schema = DataQuery .blank ()
7390 .newInput ("x1" )
7491 .newInput ("x2" )
@@ -95,6 +112,7 @@ public void test_simple_regression() {
95112
96113 MLPRegression regression = new MLPRegression ();
97114 regression .setHiddenLayers (8 );
115+ regression .setOutputNormalization (new RangeScaler ());
98116 regression .setEpoches (1000 );
99117 regression .fit (trainingData );
100118
@@ -135,6 +153,7 @@ public void test_simple_regression_weight_constraint() {
135153
136154 MLPRegression regression = new MLPRegression ();
137155 regression .setHiddenLayers (8 );
156+ regression .setOutputNormalization (new RangeScaler ());
138157 regression .setEpoches (1000 );
139158 regression .setWeightConstraint (80 );
140159 regression .fit (trainingData );
@@ -218,6 +237,7 @@ public void test_simple_regression_mini_batch_gradient_descend() {
218237 MLPRegression regression = new MLPRegression ();
219238 regression .setWeightUpdateMode (WeightUpdateMode .MiniBatchGradientDescend );
220239 regression .setMiniBatchSize (20 );
240+ regression .setOutputNormalization (new RangeScaler ());
221241 regression .setHiddenLayers (8 );
222242 regression .setEpoches (1000 );
223243 regression .fit (trainingData );
@@ -261,6 +281,7 @@ public void test_simple_regression_mini_batch_gradient_descend_L2_regularization
261281 regression .setWeightUpdateMode (WeightUpdateMode .MiniBatchGradientDescend );
262282 regression .setMiniBatchSize (20 );
263283 regression .setHiddenLayers (8 );
284+ regression .setOutputNormalization (new RangeScaler ());
264285 regression .setL2Penalty (0.001 );
265286 regression .setEpoches (1000 );
266287 regression .fit (trainingData );
@@ -305,6 +326,7 @@ public void test_simple_regression_mini_batch_gradient_descend_adaptive_learning
305326 regression .enabledAdaptiveLearningRate (true );
306327 regression .setMiniBatchSize (20 );
307328 regression .setHiddenLayers (8 );
329+ regression .setOutputNormalization (new RangeScaler ());
308330 regression .setEpoches (1000 );
309331 regression .fit (trainingData );
310332
0 commit comments