Skip to content

Commit 1ff31db

Browse files
committed
fix: Examples project uses all data, unit test uses only small fraction
1 parent 30dde0f commit 1ff31db

File tree

5 files changed

+19
-16
lines changed

5 files changed

+19
-16
lines changed

test/TensorFlowNET.Examples/LinearRegression.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public class LinearRegression : Python, IExample
2020

2121
// Parameters
2222
float learning_rate = 0.01f;
23-
int training_epochs = 1000;
23+
public int TrainingEpochs = 1000;
2424
int display_step = 50;
2525

2626
NDArray train_X, train_Y;
@@ -62,7 +62,7 @@ public bool Run()
6262
sess.run(init);
6363

6464
// Fit all training data
65-
for (int epoch = 0; epoch < training_epochs; epoch++)
65+
for (int epoch = 0; epoch < TrainingEpochs; epoch++)
6666
{
6767
foreach (var (x, y) in zip<float>(train_X, train_Y))
6868
{

test/TensorFlowNET.Examples/LogisticRegression.cs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ public class LogisticRegression : Python, IExample
2222

2323
private float learning_rate = 0.01f;
2424
public int TrainingEpochs = 10;
25-
public int DataSize = 5000;
26-
public int TestSize = 5000;
25+
public int? TrainSize = null;
26+
public int ValidationSize = 5000;
27+
public int? TestSize = null;
2728
public int BatchSize = 100;
2829
private int display_step = 1;
2930

@@ -98,7 +99,7 @@ public bool Run()
9899

99100
public void PrepareData()
100101
{
101-
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize, test_size: TestSize);
102+
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size: ValidationSize, test_size: TestSize);
102103
}
103104

104105
public void SaveModel(Session sess)
@@ -141,7 +142,7 @@ public void Predict()
141142
if (results.argmax() == (batch_ys[0] as NDArray).argmax())
142143
print("predicted OK!");
143144
else
144-
throw new ValueError("predict error, maybe 90% accuracy");
145+
throw new ValueError("predict error, should be 90% accuracy");
145146
});
146147
}
147148
}

test/TensorFlowNET.Examples/NearestNeighbor.cs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@ public class NearestNeighbor : Python, IExample
1919
public string Name => "Nearest Neighbor";
2020
Datasets mnist;
2121
NDArray Xtr, Ytr, Xte, Yte;
22-
public int DataSize = 5000;
23-
public int TestBatchSize = 200;
22+
public int? TrainSize = null;
23+
public int ValidationSize = 5000;
24+
public int? TestSize = null;
2425

2526
public bool Run()
2627
{
@@ -64,10 +65,10 @@ public bool Run()
6465

6566
public void PrepareData()
6667
{
67-
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize);
68+
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size:ValidationSize, test_size:TestSize);
6869
// In this example, we limit mnist data
69-
(Xtr, Ytr) = mnist.train.next_batch(DataSize); // 5000 for training (nn candidates)
70-
(Xte, Yte) = mnist.test.next_batch(TestBatchSize); // 200 for testing
70+
(Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates)
71+
(Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing
7172
}
7273
}
7374
}

test/TensorFlowNET.Examples/Utility/MnistDataSet.cs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,17 @@ public class MnistDataSet
1515
private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
1616
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
1717
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";
18-
1918
public static Datasets read_data_sets(string train_dir,
2019
bool one_hot = false,
2120
TF_DataType dtype = TF_DataType.TF_FLOAT,
2221
bool reshape = true,
2322
int validation_size = 5000,
24-
int test_size = 5000,
23+
int? train_size = null,
24+
int? test_size = null,
2525
string source_url = DEFAULT_SOURCE_URL)
2626
{
27-
var train_size = validation_size * 2;
27+
if (train_size!=null && validation_size >= train_size)
28+
throw new ArgumentException("Validation set should be smaller than training set");
2829

2930
Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES);
3031
Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir);

test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public void LinearRegression()
5151
[TestMethod]
5252
public void LogisticRegression()
5353
{
54-
new LogisticRegression() { Enabled = true, TrainingEpochs=10, DataSize = 500, TestSize = 500 }.Run();
54+
new LogisticRegression() { Enabled = true, TrainingEpochs=10, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run();
5555
}
5656

5757
[Ignore]
@@ -78,7 +78,7 @@ public void NamedEntityRecognition()
7878
[TestMethod]
7979
public void NearestNeighbor()
8080
{
81-
new NearestNeighbor() { Enabled = true, DataSize = 500, TestBatchSize = 100 }.Run();
81+
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run();
8282
}
8383

8484
[Ignore]

0 commit comments

Comments
 (0)