11#include " opencv2/ml/ml.hpp"
2+ #include " opencv2/core/core_c.h"
23#include < stdio.h>
4+ #include < map>
35
46void help ()
57{
@@ -10,41 +12,81 @@ void help()
1012 " CvRTrees rtrees;\n "
1113 " CvERTrees ertrees;\n "
1214 " CvGBTrees gbtrees;\n "
13- " Date is hard coded to come from filename = \" ../../../opencv/samples/c/waveform.data \" ; \n "
14- " Or can come from filename = \" ../../../opencv/samples/c/waveform.data \" ; \n "
15- " Call: \n "
16- " ./tree_engine \n\n " );
15+ " Call: \n\t ./tree_engine [-r <response_column>] [-c] <csv filename> \n "
16+ " where -r <response_column> specified the 0-based index of the response (0 by default) \n "
17+ " -c specifies that the response is categorical (it's ordered by default) and \n "
18+ " <csv filename> is the name of training data file in comma-separated value format \n\n " );
1719}
18- void print_result (float train_err, float test_err, const CvMat* var_imp)
20+
21+
22+ int count_classes (CvMLData& data)
23+ {
24+ cv::Mat r (data.get_responses ());
25+ std::map<int , int > rmap;
26+ int i, n = (int )r.total ();
27+ for ( i = 0 ; i < n; i++ )
28+ {
29+ float val = r.at <float >(i);
30+ int ival = cvRound (val);
31+ if ( ival != val )
32+ return -1 ;
33+ rmap[ival] = 1 ;
34+ }
35+ return rmap.size ();
36+ }
37+
// Prints the train error and the test error, then, when _var_imp is non-null,
// one line per variable-importance entry, and finally an empty line.
//
// NOTE(review): this span appears to be unified-diff text rather than plain
// source: lines carrying a '-' marker look like the removed version (iterating
// var_imp->cols in storage order) and lines carrying a '+' marker look like the
// replacement (iterating in the order produced by cv::sortIdx) -- confirm
// against the real file before relying on either variant.
//
// Parameters:
//   train_err - value printed on the "train error" line.
//   test_err  - value printed on the "test error" line.
//   _var_imp  - importance values; may be null, in which case the importance
//               listing is skipped.
38+ void print_result (float train_err, float test_err, const CvMat* _var_imp)
1939{
2040 printf ( " train error %f\n " , train_err );
2141 printf ( " test error %f\n\n " , test_err );
2242
23- if (var_imp )
43+ if (_var_imp )
2444 {
25- bool is_flt = false ;
26- if ( CV_MAT_TYPE ( var_imp->type ) == CV_32FC1)
27- is_flt = true ;
28- printf ( " variable impotance\n " );
29- for ( int i = 0 ; i < var_imp->cols ; i++)
// Wrap the CvMat in a cv::Mat and ask sortIdx for the per-row index order
// with the CV_SORT_DESCENDING flag; sorted_idx receives those indices.
45+ cv::Mat var_imp (_var_imp), sorted_idx;
46+ cv::sortIdx (var_imp, sorted_idx, CV_SORT_EVERY_ROW + CV_SORT_DESCENDING);
47+
48+ printf ( " variable importance:\n " );
49+ int i, n = (int )var_imp.total ();
50+ int type = var_imp.type ();
// Only float and double element reads are performed below, so any other
// element type is rejected here.
51+ CV_Assert (type == CV_32F || type == CV_64F);
52+
53+ for ( i = 0 ; i < n; i++)
3054 {
31- printf ( " %d %f\n " , i, is_flt ? var_imp->data .fl [i] : var_imp->data .db [i] );
// k is the i-th index taken from sorted_idx; it is printed together with
// the importance value stored at that index.
55+ int k = sorted_idx.at <int >(i);
56+ printf ( " %d\t %f\n " , k, type == CV_32F ? var_imp.at <float >(k) : var_imp.at <double >(k));
3257 }
3358 }
3459 printf (" \n " );
3560}
3661
37- int main ()
62+ int main (int argc, char ** argv )
3863{
39- const int train_sample_count = 300 ;
40-
41- #define LEPIOTA // Turn on discrete data set
42- #ifdef LEPIOTA // Of course, you might have to set the path here to what's on your machine ...
43- const char * filename = " ../../opencv/samples/c/agaricus-lepiota.data" ;
44- #else
45- const char * filename = " ../../opencv/samples/c/waveform.data" ;
46- #endif
47- printf (" \n Reading in %s. If it is not found, you may have to change this hard-coded path in tree_engine.cpp\n\n " ,filename);
64+ if (argc < 2 )
65+ {
66+ help ();
67+ return 0 ;
68+ }
69+ const char * filename = 0 ;
70+ int response_idx = 0 ;
71+ bool categorical_response = false ;
72+
73+ for (int i = 1 ; i < argc; i++)
74+ {
75+ if (strcmp (argv[i], " -r" ) == 0 )
76+ sscanf (argv[++i], " %d" , &response_idx);
77+ else if (strcmp (argv[i], " -c" ) == 0 )
78+ categorical_response = true ;
79+ else if (argv[i][0 ] != ' -' )
80+ filename = argv[i];
81+ else
82+ {
83+ printf (" Error. Invalid option %s\n " , argv[i]);
84+ help ();
85+ return -1 ;
86+ }
87+ }
88+
89+ printf (" \n Reading in %s...\n\n " ,filename);
4890 CvDTree dtree;
4991 CvBoost boost;
5092 CvRTrees rtrees;
@@ -53,29 +95,26 @@ int main()
5395
5496 CvMLData data;
5597
56- CvTrainTestSplit spl ( train_sample_count );
98+
99+ CvTrainTestSplit spl ( 0 .5f );
57100
58101 if ( data.read_csv ( filename ) == 0 )
59102 {
60-
61- #ifdef LEPIOTA
62- data.set_response_idx ( 0 );
63- #else
64- data.set_response_idx ( 21 );
65- data.change_var_type ( 21 , CV_VAR_CATEGORICAL );
66- #endif
67-
103+ data.set_response_idx ( response_idx );
104+ if (categorical_response)
105+ data.change_var_type ( response_idx, CV_VAR_CATEGORICAL );
68106 data.set_train_test_split ( &spl );
69107
70108 printf (" ======DTREE=====\n " );
71109 dtree.train ( &data, CvDTreeParams ( 10 , 2 , 0 , false , 16 , 0 , false , false , 0 ));
72110 print_result ( dtree.calc_error ( &data, CV_TRAIN_ERROR), dtree.calc_error ( &data, CV_TEST_ERROR ), dtree.get_var_importance () );
73111
74- #ifdef LEPIOTA
112+ if ( categorical_response && count_classes (data) == 2 )
113+ {
75114 printf (" ======BOOST=====\n " );
76115 boost.train ( &data, CvBoostParams (CvBoost::DISCRETE, 100 , 0.95 , 2 , false , 0 ));
77116 print_result ( boost.calc_error ( &data, CV_TRAIN_ERROR ), boost.calc_error ( &data, CV_TEST_ERROR ), 0 ); // doesn't compute importance
78- # endif
117+ }
79118
80119 printf (" ======RTREES=====\n " );
81120 rtrees.train ( &data, CvRTParams ( 10 , 2 , 0 , false , 16 , 0 , true , 0 , 100 , 0 , CV_TERMCRIT_ITER ));
0 commit comments