@@ -17,12 +17,16 @@ namespace TensorFlowNET.Examples
1717 public class LogisticRegression : Python , IExample
1818 {
1919 public int Priority => 4 ;
20- public bool Enabled => true ;
20+ public bool Enabled { get ; set ; } = true ;
2121 public string Name => "Logistic Regression" ;
2222
23+ public int training_epochs = 10 ;
24+ public int ? train_size = null ;
25+ public int validation_size = 5000 ;
26+ public int ? test_size = null ;
27+ public int batch_size = 100 ;
28+
2329 private float learning_rate = 0.01f ;
24- private int training_epochs = 10 ;
25- private int batch_size = 100 ;
2630 private int display_step = 1 ;
2731
2832 Datasets mnist ;
@@ -96,7 +100,7 @@ public bool Run()
96100
97101 public void PrepareData ( )
98102 {
99- mnist = MnistDataSet . read_data_sets ( "mnist" , one_hot : true ) ;
103+ mnist = MnistDataSet . read_data_sets ( "mnist" , one_hot : true , train_size : train_size , validation_size : validation_size , test_size : test_size ) ;
100104 }
101105
102106 public void SaveModel ( Session sess )
@@ -139,7 +143,7 @@ public void Predict()
139143 if ( results . argmax ( ) == ( batch_ys [ 0 ] as NDArray ) . argmax ( ) )
140144 print ( "predicted OK!" ) ;
141145 else
142- throw new ValueError ( "predict error, maybe 90% accuracy" ) ;
146+ throw new ValueError ( "predict error, should be 90% accuracy" ) ;
143147 } ) ;
144148 }
145149 }
0 commit comments