@@ -91,39 +91,42 @@ def predict(self, samples):
9191
9292
9393if __name__ == '__main__' :
94- import argparse
94+ import getopt
95+ import sys
9596
9697 models = [RTrees , KNearest , Boost , SVM ] # MLP, NBayes
9798 models = dict ( [(cls .__name__ .lower (), cls ) for cls in models ] )
99+
100+ print 'USAGE: letter_recog.py [--model <model>] [--data <data fn>] [--load <model fn>] [--save <model fn>]'
101+ print 'Models: ' , ', ' .join (models )
102+ print
98103
99- parser = argparse .ArgumentParser ()
100- parser .add_argument ('-model' , default = 'rtrees' , choices = models .keys ())
101- parser .add_argument ('-data' , nargs = 1 , default = '../cpp/letter-recognition.data' )
102- parser .add_argument ('-load' , nargs = 1 )
103- parser .add_argument ('-save' , nargs = 1 )
104- args = parser .parse_args ()
105-
106- print 'loading data %s ...' % args .data
107- samples , responses = load_base (args .data )
108- Model = models [args .model ]
104+ args , dummy = getopt .getopt (sys .argv [1 :], '' , ['model=' , 'data=' , 'load=' , 'save=' ])
105+ args = dict (args )
106+ args .setdefault ('--model' , 'rtrees' )
107+ args .setdefault ('--data' , '../cpp/letter-recognition.data' )
108+
109+ print 'loading data %s ...' % args ['--data' ]
110+ samples , responses = load_base (args ['--data' ])
111+ Model = models [args ['--model' ]]
109112 model = Model ()
110113
111114 train_n = int (len (samples )* model .train_ratio )
112- if args .load is None :
113- print 'training %s ...' % Model .__name__
114- model .train (samples [:train_n ], responses [:train_n ])
115- else :
116- fn = args .load [0 ]
115+ if '--load' in args :
116+ fn = args ['--load' ]
117117 print 'loading model from %s ...' % fn
118118 model .load (fn )
119+ else :
120+ print 'training %s ...' % Model .__name__
121+ model .train (samples [:train_n ], responses [:train_n ])
119122
120123 print 'testing...'
121124 train_rate = np .mean (model .predict (samples [:train_n ]) == responses [:train_n ])
122125 test_rate = np .mean (model .predict (samples [train_n :]) == responses [train_n :])
123126
124127 print 'train rate: %f test rate: %f' % (train_rate * 100 , test_rate * 100 )
125128
126- if args . save is not None :
127- fn = args . save [ 0 ]
129+ if '-- save' in args :
130+ fn = args [ '--save' ]
128131 print 'saving model to %s ...' % fn
129132 model .save (fn )
0 commit comments