-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy pathAdaboost.java
More file actions
169 lines (140 loc) · 5.8 KB
/
Adaboost.java
File metadata and controls
169 lines (140 loc) · 5.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
/**
 * A decision-forest classifier built from our DecisionTree implementation:
 * the constructor trains an ensemble of trees, each on a random subset of
 * the training examples, and prediction is by majority vote. The underlying
 * tree algorithm is based on R&N 18.3, page 702.
 *
 * http://csmr.ca.sandia.gov/~wpk/pubs/publications/pami06.pdf
 */
import java.io.*;
import java.util.Random;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Collections;
public class Adaboost implements Classifier {
    /** RNG used for sampling examples/attributes; seed for reproducibility. */
    final Random random;
    /** The trained ensemble; predict() takes a majority vote over these. */
    final DecisionTree[] forest;

    /**
     * Trains {@code forestSize} decision trees. Each tree sees a random
     * 2/3 sample of the training examples and (currently) all attributes.
     * NOTE(review): despite the class name this is bagging, not AdaBoost —
     * there is no example re-weighting.
     *
     * @param data       training set (examples, labels, attribute count)
     * @param forestSize number of trees in the ensemble
     */
    public Adaboost(DataSet data, int forestSize) {
        random = new Random();
        forest = new DecisionTree[forestSize];
        /*
         * Candidate pools of attribute and example indices; each tree gets
         * a shuffled prefix of these. Optimal subset sizes are unclear, so
         * for now: every attribute, 2/3 of the examples.
         */
        ArrayList<Integer> attributes = new ArrayList<Integer>(data.numAttrs);
        ArrayList<Integer> examples = new ArrayList<Integer>(data.numTrainExs);
        for (int i = 0; i < data.numAttrs; i++) { attributes.add(i); }
        for (int i = 0; i < data.numTrainExs; i++) { examples.add(i); }
        // Every feature for every tree (randomized feature subsets disabled).
        int numFeatures = data.numAttrs;
        // Each tree trains on a sample of 2/3 of the examples.
        int numTrain = 2 * data.numTrainExs / 3;
        for (int cTree = 0; cTree < forestSize; cTree++) {
            /* Possible alternatives for per-tree subset sizes:
             *   int numFeatures = random.nextInt(data.numAttrs - 1) + 1;
             *   int numTrain = random.nextInt(data.numTrainExs);
             */
            HashSet<Integer> treeAttributes = new HashSet<Integer>(numFeatures);
            ArrayList<Integer> treeExamples = new ArrayList<Integer>(numTrain);
            // Shuffle with our own RNG so that seeding `random` actually
            // makes tree construction reproducible (the no-arg overload
            // would use its own internal source of randomness).
            Collections.shuffle(attributes, random);
            for (int i = 0; i < numFeatures; i++) {
                treeAttributes.add(attributes.get(i));
            }
            Collections.shuffle(examples, random);
            for (int i = 0; i < numTrain; i++) {
                treeExamples.add(examples.get(i));
            }
            forest[cTree] = new DecisionTree(data, treeAttributes,
                                             treeExamples, false);
        }
    }

    /**
     * Classifies an example by majority vote of the trees.
     * Ties (and an empty forest) resolve to class 0.
     *
     * @param ex attribute vector of the example to classify
     * @return the predicted binary label, 0 or 1
     */
    public int predict(int[] ex) {
        int[] count = new int[2];
        for (DecisionTree tree : forest)
            count[tree.predict(ex)]++;
        return (count[1] > count[0] ? 1 : 0);
    }

    /** This method should return a very brief but understandable
     * description of the learning algorithm that is being used,
     * appropriate for posting on the class website.
     */
    public String algorithmDescription() {
        return "Basic decision forest - uses our DecisionTree";
    }

    /** This method should return the "author" of this program as you
     * would like it to appear on the class website. You can use your
     * real name, or a pseudonym, or a name that identifies your
     * group.
     */
    public String author() {
        return "dmrd";
    }

    /**
     * Simple test harness: shuffles the data set, holds out the last
     * quarter as a cross-validation set, trains on the rest, and reports
     * accuracy on the held-out quarter.
     *
     * @param argv {filestem, forestSize}
     */
    public static void main(String argv[])
            throws FileNotFoundException, IOException {
        if (argv.length < 2) {
            System.err.println("argument: filestem forestSize");
            return;
        }
        String filestem = argv[0];
        /*
         * Create a cross validation set - just takes the last crossSize
         * elements of the set as a cross set.
         */
        DiscreteDataSet d = new DiscreteDataSet(filestem);
        /*
         * Knuth shuffle of examples and labels in parallel, so the
         * held-out quarter is a random sample.
         */
        // Set seed to constant to get the same result multiple times
        Random random = new Random();
        for (int i = 0; i < d.numTrainExs; i++) {
            int swap = random.nextInt(d.numTrainExs - i);
            int[] tempEx = d.trainEx[swap];
            d.trainEx[swap] = d.trainEx[d.numTrainExs - i - 1];
            d.trainEx[d.numTrainExs - i - 1] = tempEx;
            int tempLabel = d.trainLabel[swap];
            d.trainLabel[swap] = d.trainLabel[d.numTrainExs - i - 1];
            d.trainLabel[d.numTrainExs - i - 1] = tempLabel;
        }
        int crossSize = d.numTrainExs / 4;
        int[][] crossEx = new int[crossSize][];
        int[] crossLabel = new int[crossSize];
        int[][] dEx = new int[d.numTrainExs - crossSize][];
        int[] dLabel = new int[d.numTrainExs - crossSize];
        for (int i = 0; i < d.numTrainExs - crossSize; i++) {
            dEx[i] = d.trainEx[i];
            dLabel[i] = d.trainLabel[i];
        }
        for (int i = 0; i < crossSize; i++) {
            crossEx[i] = d.trainEx[d.numTrainExs - i - 1];
            crossLabel[i] = d.trainLabel[d.numTrainExs - i - 1];
        }
        // Shrink the original dataset to the training portion only.
        d.numTrainExs = dEx.length;
        d.trainEx = dEx;
        d.trainLabel = dLabel;
        System.out.println("Training classifier on " + d.numTrainExs
                           + " examples");
        // BUG FIX: was `new DecisionForest(...)`, which tested a different
        // class entirely; this harness should exercise Adaboost itself.
        Classifier c = new Adaboost(d, Integer.parseInt(argv[1]));
        System.out.println("Testing classifier on " + crossEx.length
                           + " examples");
        int correct = 0;
        for (int ex = 0; ex < crossEx.length; ex++) {
            if (c.predict(crossEx[ex]) == crossLabel[ex])
                correct++;
        }
        // Guard: with fewer than 4 examples crossSize is 0 and the
        // percentage below would divide by zero.
        if (crossEx.length > 0) {
            System.out.println("Performance on cross set: "
                               + (100 * correct / crossEx.length) + "%");
        } else {
            System.out.println("Cross set is empty; no performance to report");
        }
    }
}