-
Notifications
You must be signed in to change notification settings - Fork 35
Expand file tree
/
Copy pathDecisionTree.java
More file actions
366 lines (318 loc) · 13.2 KB
/
DecisionTree.java
File metadata and controls
366 lines (318 loc) · 13.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
/**
 * A decision tree classifier, used as the base learner in a decision-forest
 * implementation. The constructor takes the training set and the attributes
 * to use for building the tree. The basic algorithm follows Russell & Norvig
 * (AIMA), section 18.3, page 702.
 */
import java.io.*;
import java.util.Random;
import java.util.ArrayList;
import java.util.HashSet;
public class DecisionTree implements Classifier {
    /** Root of the trained decision tree; set once by the constructors. */
    Node treeRoot;

    /** Randomness source for attribute sub-sampling in random-forest mode. */
    Random random;

    /*
     * Randomize tree (for use in a random forest)? If true, each node picks
     * its split attribute from a random subset of the remaining attributes
     * rather than from all of them.
     */
    boolean randomize;

    /**
     * How many candidate features a randomized node should consider.
     * Uses sqrt(total) + 1, a common random-forest heuristic; the "+ 1"
     * guarantees at least one candidate even for tiny attribute sets.
     *
     * @param total number of attributes still available at this node
     * @return size of the random attribute subset to sample from
     */
    private int numFeatures(int total) {
        return (int) Math.sqrt(total) + 1;
    }

    /*
     * Inner class representing the tree structure. Each internal node splits
     * on one discrete-valued attribute; each leaf carries a class label.
     */
    private class Node {
        /** Attribute index this node splits on; -1 marks a leaf. */
        public int attribute;

        /** Class label (0 or 1) returned at a leaf; -1 on internal nodes. */
        public int label;

        /** One child per possible value of the split attribute (null on leaves). */
        public Node[] children;

        /**
         * Entropy (in bits) of a boolean variable observed n times negative
         * and p times positive.
         *
         * @param n count of negative examples
         * @param p count of positive examples
         * @return entropy in [0, 1]; defined as 0 when either count is zero
         */
        double entropy(double n, double p) {
            if (n == 0 || p == 0)
                return 0.0;
            // Divide by log(2) to convert the natural log to base 2 (bits).
            return -1.0
                * ( ((n/(n+p)) * Math.log(n/(n+p)))
                        + ( (p/(n+p)) * Math.log(p/(n+p)))) / Math.log(2);
        }

        /**
         * Picks the attribute with maximum information gain over the given
         * examples — equivalently, minimum Remainder(A), since B(.) is the
         * same for every attribute. See R&N 18.3.4, page 703.
         *
         * @param data       full training data set
         * @param attributes candidate attributes to evaluate
         * @param examples   indices into data.trainEx under consideration
         * @return index of the best attribute, or -1 if attributes is empty
         */
        int chooseAttribute(DataSet data, HashSet<Integer> attributes,
                ArrayList<Integer> examples) {
            int bestAttr = -1;
            double bestGain = -1;

            // Label counts over all examples give the entropy of the set
            // before any split.
            int[] labelCount = new int[2];
            for (int ex : examples) {
                labelCount[data.trainLabel[ex]]++;
            }
            double setEntropy = entropy(labelCount[0], labelCount[1]);

            for (int attr : attributes) {
                /*
                 * count[value][label]: e.g. count[0][0] + count[0][1] is the
                 * number of examples taking value 0 for this attribute, and
                 * count[0][1] + count[1][1] is the number of label-1 examples.
                 * Used to compute the post-split weighted entropy.
                 */
                double[][] count = new double[data.attrVals[attr].length][2];
                for (int ex : examples) {
                    count[data.trainEx[ex][attr]][data.trainLabel[ex]]++;
                }

                // Gain = entropy(before split) - weighted entropy(after split).
                double gain = setEntropy;
                for (int val = 0; val < data.attrVals[attr].length; val++) {
                    gain -= ((count[val][0] + count[val][1]) / examples.size())
                        * entropy(count[val][0], count[val][1]);
                }
                // ">=" (not ">") so bestAttr is always set when the candidate
                // set is non-empty, even if every gain is zero.
                if (gain >= bestGain) {
                    bestAttr = attr;
                    bestGain = gain;
                }
            }
            return bestAttr;
        }

        /**
         * Recursively builds the subtree rooted at this node.
         *
         * Examples are passed as indices into data.trainEx so we never copy
         * the examples themselves. The attributes set is shared mutable state
         * across the recursion: the chosen attribute is removed before the
         * children are built and re-added before this constructor returns.
         *
         * @param data       full training data set
         * @param attributes attributes still available to split on (mutated
         *                   during recursion, restored on return)
         * @param examples   indices of the training examples at this node
         */
        Node(DataSet data, HashSet<Integer> attributes,
                ArrayList<Integer> examples) {
            this.label = -1;
            if (examples.size() == 0) {
                this.attribute = -1;
                this.label = 0; // Placeholder to avoid crashes; the parent
                return;         // overwrites this with its majority label.
            }

            // Compute the majority class among this node's examples.
            int count[] = new int[2];
            for (int ex : examples) {
                count[data.trainLabel[ex]]++;
            }
            int majority = (count[1] > count[0] ? 1 : 0);

            // Leaf cases: every example has the same label, or there are no
            // attributes left to split on.
            if (count[majority] == examples.size() || attributes.size() == 0) {
                this.attribute = -1;
                this.label = majority;
                return;
            }

            /*
             * In random-forest mode, restrict the search for the best split
             * to a random subset of the remaining attributes.
             */
            if (randomize) {
                int numAttr = numFeatures(attributes.size());
                HashSet<Integer> attrSample = new HashSet<Integer>(numAttr);
                for (int attr : attributes) {
                    /*
                     * Include each attribute with probability
                     * numAttr / attributes.size(), so the sample size varies
                     * slightly around numAttr, adding a little extra variance.
                     */
                    if (random.nextInt(attributes.size()) < numAttr) {
                        attrSample.add(attr);
                    }
                }
                this.attribute = chooseAttribute(data, attrSample, examples);
            } else {
                this.attribute = chooseAttribute(data, attributes, examples);
            }

            // No usable attribute (e.g. the random sample came up empty):
            // fall back to a majority-label leaf.
            if (this.attribute == -1) {
                this.label = majority;
                return;
            }

            // Remove the chosen attribute so descendants cannot reuse it;
            // restored before returning.
            attributes.remove(this.attribute);

            // Partition the examples by their value of the chosen attribute.
            ArrayList<ArrayList<Integer>> childExamples = new
                ArrayList<ArrayList<Integer>>
                (data.attrVals[this.attribute].length);
            for (int i = 0; i < data.attrVals[this.attribute].length; i++) {
                childExamples.add(new ArrayList<Integer>());
            }
            for (int ex : examples) {
                childExamples.get(data.trainEx[ex][this.attribute]).add(ex);
            }

            // Build one child subtree per attribute value.
            children = new Node[data.attrVals[this.attribute].length];
            for (int i = 0; i < data.attrVals[this.attribute].length; i++) {
                children[i] = new Node(data,
                        attributes,
                        childExamples.get(i));
                // A child with no training examples inherits this node's
                // majority label (the child constructor left a placeholder).
                if (childExamples.get(i).size() == 0) {
                    children[i].label = majority;
                }
            }
            attributes.add(this.attribute);
        }
    }

    /**
     * Trains a tree on the full data set using all attributes.
     *
     * @param data training data
     * @param rand true to randomize attribute selection (random-forest mode)
     */
    public DecisionTree(DataSet data, boolean rand) {
        random = new Random();
        this.randomize = rand;
        HashSet<Integer> attributes = new HashSet<Integer>(data.numAttrs);
        ArrayList<Integer> examples = new ArrayList<Integer>(data.numTrainExs);
        // Initialize example and attribute lists to cover everything.
        for (int i = 0; i < data.numAttrs; i++) { attributes.add(i); }
        for (int i = 0; i < data.numTrainExs; i++) { examples.add(i); }
        treeRoot = new Node(data, attributes, examples);
    }

    /**
     * Trains a tree using only the given attributes, on all examples.
     *
     * @param data       training data
     * @param attributes attributes to consider when splitting
     * @param rand       true to randomize attribute selection
     */
    public DecisionTree(DataSet data, HashSet<Integer> attributes, boolean rand) {
        random = new Random();
        this.randomize = rand;
        // Initialize the example list to include every training example.
        ArrayList<Integer> examples = new ArrayList<Integer>(data.numTrainExs);
        for (int i = 0; i < data.numTrainExs; i++) { examples.add(i); }
        treeRoot = new Node(data, attributes, examples);
    }

    /**
     * Trains a tree using only the given attributes and examples.
     *
     * @param data       training data
     * @param attributes attributes to consider when splitting
     * @param examples   indices into data.trainEx to train on
     * @param rand       true to randomize attribute selection
     */
    public DecisionTree(DataSet data, HashSet<Integer> attributes,
            ArrayList<Integer> examples, boolean rand) {
        random = new Random();
        this.randomize = rand;
        treeRoot = new Node(data, attributes, examples);
    }

    /**
     * Walks down the generated tree to return a label for the example.
     *
     * @param ex attribute values of the example, indexed by attribute
     * @return predicted class label (0 or 1)
     */
    public int predict(int[] ex) {
        Node current = treeRoot;
        while (current.attribute != -1) {
            current = current.children[ex[current.attribute]];
        }
        return current.label;
    }

    /** This method should return a very brief but understandable
     * description of the learning algorithm that is being used,
     * appropriate for posting on the class website.
     */
    public String algorithmDescription() {
        return "Basic decision tree for use with random forests";
    }

    /** This method should return the "author" of this program as you
     * would like it to appear on the class website. You can use your
     * real name, or a pseudonym, or a name that identifies your
     * group.
     */
    public String author() {
        return "dmrd";
    }

    /*
     * Simple main for testing: trains on 7/8 of a shuffled data set and
     * reports accuracy on the held-out 1/8.
     */
    public static void main(String argv[])
        throws FileNotFoundException, IOException {
        if (argv.length < 1) {
            System.err.println("argument: filestem");
            return;
        }
        String filestem = argv[0];

        DiscreteDataSet d = new DiscreteDataSet(filestem);

        // Knuth shuffle so the held-out cross-validation slice (taken from
        // the tail of the array) is a random sample. Set the seed to a
        // constant to reproduce a run.
        Random random = new Random();
        for (int i = 0; i < d.numTrainExs; i++) {
            int swap = random.nextInt(d.numTrainExs - i);
            int[] tempEx = d.trainEx[swap];
            d.trainEx[swap] = d.trainEx[d.numTrainExs - i - 1];
            d.trainEx[d.numTrainExs - i - 1] = tempEx;
            int tempLabel = d.trainLabel[swap];
            d.trainLabel[swap] = d.trainLabel[d.numTrainExs - i - 1];
            d.trainLabel[d.numTrainExs - i - 1] = tempLabel;
        }

        // Hold out 1/8 of the shuffled data for cross-validation.
        int crossSize = d.numTrainExs/8;
        int[][] crossEx = new int[crossSize][];
        int[] crossLabel = new int[crossSize];
        int[][] dEx = new int[d.numTrainExs - crossSize][];
        int[] dLabel = new int[d.numTrainExs - crossSize];
        for (int i = 0; i < d.numTrainExs - crossSize; i++) {
            dEx[i] = d.trainEx[i];
            dLabel[i] = d.trainLabel[i];
        }
        for (int i = 0; i < crossSize; i++) {
            crossEx[i] = d.trainEx[d.numTrainExs - i - 1];
            crossLabel[i] = d.trainLabel[d.numTrainExs - i - 1];
        }

        // Shrink the original data set to just the training portion.
        d.numTrainExs = dEx.length;
        d.trainEx = dEx;
        d.trainLabel = dLabel;

        System.out.println("Training classifier on " + d.numTrainExs
                + " examples");
        Classifier c = new DecisionTree(d, false);

        System.out.println("Testing classifier on " + crossEx.length
                + " examples");
        int correct = 0;
        for (int ex = 0; ex < crossEx.length; ex++) {
            if (c.predict(crossEx[ex]) == crossLabel[ex])
                correct++;
        }
        // Guard against division by zero when the data set has fewer than
        // 8 examples (crossSize == 0).
        int pct = crossEx.length > 0 ? (100 * correct) / crossEx.length : 0;
        System.out.println("Performance on cross set: " + pct + "%");
    }
}