Skip to content

Commit bbdd0ae

Browse files
author
Vadim Pisarevsky
committed
improved tree_engine.cpp sample (added train file data specification; print sorted variable importance table)
1 parent ce474db commit bbdd0ae

File tree

1 file changed

+72
-33
lines changed

1 file changed

+72
-33
lines changed

samples/c/tree_engine.cpp

Lines changed: 72 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "opencv2/ml/ml.hpp"
2+
#include "opencv2/core/core_c.h"
23
#include <stdio.h>
4+
#include <map>
35

46
void help()
57
{
@@ -10,41 +12,81 @@ void help()
1012
"CvRTrees rtrees;\n"
1113
"CvERTrees ertrees;\n"
1214
"CvGBTrees gbtrees;\n"
13-
"Date is hard coded to come from filename = \"../../../opencv/samples/c/waveform.data\";\n"
14-
"Or can come from filename = \"../../../opencv/samples/c/waveform.data\";\n"
15-
"Call:\n"
16-
"./tree_engine\n\n");
15+
"Call:\n\t./tree_engine [-r <response_column>] [-c] <csv filename>\n"
16+
"where -r <response_column> specified the 0-based index of the response (0 by default)\n"
17+
"-c specifies that the response is categorical (it's ordered by default) and\n"
18+
"<csv filename> is the name of training data file in comma-separated value format\n\n");
1719
}
18-
void print_result(float train_err, float test_err, const CvMat* var_imp)
20+
21+
22+
int count_classes(CvMLData& data)
23+
{
24+
cv::Mat r(data.get_responses());
25+
std::map<int, int> rmap;
26+
int i, n = (int)r.total();
27+
for( i = 0; i < n; i++ )
28+
{
29+
float val = r.at<float>(i);
30+
int ival = cvRound(val);
31+
if( ival != val )
32+
return -1;
33+
rmap[ival] = 1;
34+
}
35+
return rmap.size();
36+
}
37+
38+
void print_result(float train_err, float test_err, const CvMat* _var_imp)
1939
{
2040
printf( "train error %f\n", train_err );
2141
printf( "test error %f\n\n", test_err );
2242

23-
if (var_imp)
43+
if (_var_imp)
2444
{
25-
bool is_flt = false;
26-
if ( CV_MAT_TYPE( var_imp->type ) == CV_32FC1)
27-
is_flt = true;
28-
printf( "variable impotance\n" );
29-
for( int i = 0; i < var_imp->cols; i++)
45+
cv::Mat var_imp(_var_imp), sorted_idx;
46+
cv::sortIdx(var_imp, sorted_idx, CV_SORT_EVERY_ROW + CV_SORT_DESCENDING);
47+
48+
printf( "variable importance:\n" );
49+
int i, n = (int)var_imp.total();
50+
int type = var_imp.type();
51+
CV_Assert(type == CV_32F || type == CV_64F);
52+
53+
for( i = 0; i < n; i++)
3054
{
31-
printf( "%d %f\n", i, is_flt ? var_imp->data.fl[i] : var_imp->data.db[i] );
55+
int k = sorted_idx.at<int>(i);
56+
printf( "%d\t%f\n", k, type == CV_32F ? var_imp.at<float>(k) : var_imp.at<double>(k));
3257
}
3358
}
3459
printf("\n");
3560
}
3661

37-
int main()
62+
int main(int argc, char** argv)
3863
{
39-
const int train_sample_count = 300;
40-
41-
#define LEPIOTA //Turn on discrete data set
42-
#ifdef LEPIOTA //Of course, you might have to set the path here to what's on your machine ...
43-
const char* filename = "../../opencv/samples/c/agaricus-lepiota.data";
44-
#else
45-
const char* filename = "../../opencv/samples/c/waveform.data";
46-
#endif
47-
printf("\n Reading in %s. If it is not found, you may have to change this hard-coded path in tree_engine.cpp\n\n",filename);
64+
if(argc < 2)
65+
{
66+
help();
67+
return 0;
68+
}
69+
const char* filename = 0;
70+
int response_idx = 0;
71+
bool categorical_response = false;
72+
73+
for(int i = 1; i < argc; i++)
74+
{
75+
if(strcmp(argv[i], "-r") == 0)
76+
sscanf(argv[++i], "%d", &response_idx);
77+
else if(strcmp(argv[i], "-c") == 0)
78+
categorical_response = true;
79+
else if(argv[i][0] != '-' )
80+
filename = argv[i];
81+
else
82+
{
83+
printf("Error. Invalid option %s\n", argv[i]);
84+
help();
85+
return -1;
86+
}
87+
}
88+
89+
printf("\nReading in %s...\n\n",filename);
4890
CvDTree dtree;
4991
CvBoost boost;
5092
CvRTrees rtrees;
@@ -53,29 +95,26 @@ int main()
5395

5496
CvMLData data;
5597

56-
CvTrainTestSplit spl( train_sample_count );
98+
99+
CvTrainTestSplit spl( 0.5f );
57100

58101
if ( data.read_csv( filename ) == 0)
59102
{
60-
61-
#ifdef LEPIOTA
62-
data.set_response_idx( 0 );
63-
#else
64-
data.set_response_idx( 21 );
65-
data.change_var_type( 21, CV_VAR_CATEGORICAL );
66-
#endif
67-
103+
data.set_response_idx( response_idx );
104+
if(categorical_response)
105+
data.change_var_type( response_idx, CV_VAR_CATEGORICAL );
68106
data.set_train_test_split( &spl );
69107

70108
printf("======DTREE=====\n");
71109
dtree.train( &data, CvDTreeParams( 10, 2, 0, false, 16, 0, false, false, 0 ));
72110
print_result( dtree.calc_error( &data, CV_TRAIN_ERROR), dtree.calc_error( &data, CV_TEST_ERROR ), dtree.get_var_importance() );
73111

74-
#ifdef LEPIOTA
112+
if( categorical_response && count_classes(data) == 2 )
113+
{
75114
printf("======BOOST=====\n");
76115
boost.train( &data, CvBoostParams(CvBoost::DISCRETE, 100, 0.95, 2, false, 0));
77116
print_result( boost.calc_error( &data, CV_TRAIN_ERROR ), boost.calc_error( &data, CV_TEST_ERROR ), 0 ); //doesn't compute importance
78-
#endif
117+
}
79118

80119
printf("======RTREES=====\n");
81120
rtrees.train( &data, CvRTParams( 10, 2, 0, false, 16, 0, true, 0, 100, 0, CV_TERMCRIT_ITER ));

0 commit comments

Comments
 (0)