-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy pathBaselineClassifier.java
More file actions
75 lines (58 loc) · 2.06 KB
/
BaselineClassifier.java
File metadata and controls
75 lines (58 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import java.io.*;
/**
 * An extremely simple learning algorithm that finds the most frequent
 * class in the training data and then predicts that every new test
 * example belongs to that class.
 */
public class BaselineClassifier implements Classifier {

    /** Author string reported by {@link #author()}. */
    private static final String AUTHOR = "Rob Schapire";

    /** Description string reported by {@link #algorithmDescription()}. */
    private static final String DESCRIPTION = "A very simple learning algorithm that, "
        + "on each test example, predicts with the most frequent class seen "
        + "during training";

    /** The label (0 or 1) that occurred most often among the training examples. */
    private final int mostFrequentClass;

    /**
     * Counts the training labels of the given dataset and stores the
     * most frequent one.
     *
     * <p>Each label is used directly as an index into a two-element
     * count array, so every label must be 0 or 1; any other value
     * results in an {@link ArrayIndexOutOfBoundsException}. Ties (and an
     * empty training set) resolve to class 0.
     *
     * @param d the dataset whose first {@code numTrainExs} entries of
     *          {@code trainLabel} are counted
     */
    public BaselineClassifier(DataSet d) {
        int[] count = new int[2];
        for (int i = 0; i < d.numTrainExs; i++) {
            count[d.trainLabel[i]]++;
        }
        // Strict comparison: class 1 wins only when it is strictly more frequent.
        mostFrequentClass = (count[1] > count[0]) ? 1 : 0;
    }

    /**
     * Ignores the given example and predicts the most frequent class
     * seen during training.
     *
     * @param ex the example to classify (unused)
     * @return the most frequent training label
     */
    public int predict(int[] ex) {
        return mostFrequentClass;
    }

    /**
     * Returns a description of the learning algorithm.
     *
     * @return the algorithm description
     */
    public String algorithmDescription() {
        return DESCRIPTION;
    }

    /**
     * Returns the author of this program.
     *
     * @return the author name
     */
    public String author() {
        return AUTHOR;
    }

    /**
     * A simple main for testing this algorithm. Reads a filestem from
     * the command line, trains the classifier on that dataset, prints
     * the accuracy on the training examples as a whole-number percentage
     * (truncated), and then hands the classifier to
     * {@code DataSet.printTestPredictions} (which, per the original
     * documentation of this file, writes to filestem.testout).
     *
     * <p>If no filestem is given, a usage message is printed to standard
     * error and the method returns. If the dataset contains no training
     * examples, the accuracy is undefined; a note is printed to standard
     * error instead of dividing by zero, and the test predictions are
     * still produced.
     *
     * @param argv command-line arguments; {@code argv[0]} is the filestem
     * @throws FileNotFoundException propagated from dataset loading or output
     * @throws IOException propagated from dataset loading or output
     */
    public static void main(String argv[])
            throws FileNotFoundException, IOException {

        if (argv.length < 1) {
            System.err.println("argument: filestem");
            return;
        }

        String filestem = argv[0];

        DataSet d = new DataSet(filestem);

        Classifier c = new BaselineClassifier(d);

        int correct = 0;
        for (int i = 0; i < d.numTrainExs; i++) {
            if (c.predict(d.trainEx[i]) == d.trainLabel[i]) {
                correct++;
            }
        }

        if (d.numTrainExs > 0) {
            // Widen to long so 100 * correct cannot overflow for large datasets.
            System.out.println((100L * correct / d.numTrainExs) + "%");
        } else {
            // Avoid ArithmeticException from integer division by zero.
            System.err.println("no training examples; training accuracy is undefined");
        }

        d.printTestPredictions(c, filestem);
    }
}