@@ -28,20 +28,21 @@ define([
2828 class Evaluation extends PopupComponent {
2929 _init ( ) {
3030 super . _init ( ) ;
31+ this . config . importButton = true ;
3132 this . config . dataview = false ;
3233
3334 this . state = {
3435 modelType : 'rgs' ,
3536 predictData : 'pred' ,
3637 targetData : 'y_test' ,
38+ // regression
39+ r_squared : true , mae : true , mape : false , rmse : true , scatter_plot : false ,
3740 // classification
3841 confusion_matrix : true , report : true ,
3942 accuracy : false , precision : false , recall : false , f1_score : false ,
40- // regression
41- coefficient : false , intercept : false , r_squared : true ,
42- mae : false , mape : false , rmse : true , scatter_plot : false ,
43+ roc_curve : false , auc : false ,
4344 // clustering
44- sizeOfClusters : true , silhouetteScore : true ,
45+ silhouetteScore : true , ari : false , nm : false ,
4546 ...this . state
4647 }
4748 }
@@ -63,7 +64,39 @@ define([
6364
6465 $ ( that . wrapSelector ( '.vp-eval-box' ) ) . hide ( ) ;
6566 $ ( that . wrapSelector ( '.vp-eval-' + modelType ) ) . show ( ) ;
66- } )
67+
68+ if ( modelType == 'clf' ) {
69+ // Classification - model selection
70+ if ( that . checkToShowModel ( ) == true ) {
71+ $ ( that . wrapSelector ( '.vp-ev-model' ) ) . show ( ) ;
72+ }
73+ }
74+ } ) ;
75+
76+ // open model selection show
77+ $ ( this . wrapSelector ( '.vp-eval-check' ) ) . on ( 'change' , function ( ) {
78+ let checked = $ ( this ) . prop ( 'checked' ) ;
79+
80+ if ( checked ) {
81+ $ ( that . wrapSelector ( '.vp-ev-model' ) ) . show ( ) ;
82+ } else {
83+ if ( that . checkToShowModel ( ) == false ) {
84+ $ ( that . wrapSelector ( '.vp-ev-model' ) ) . hide ( ) ;
85+ }
86+ }
87+ } ) ;
88+ }
89+
90+ /**
91+ * Check if anything checked available ( > 0)
92+ * @returns
93+ */
94+ checkToShowModel ( ) {
95+ let checked = $ ( this . wrapSelector ( '.vp-eval-check:checked' ) ) . length ;
96+ if ( checked > 0 ) {
97+ return true ;
98+ }
99+ return false ;
67100 }
68101
69102 templateForBody ( ) {
@@ -72,7 +105,7 @@ define([
72105 $ ( page ) . find ( '.vp-eval-box' ) . hide ( ) ;
73106 $ ( page ) . find ( '.vp-eval-' + this . state . modelType ) . show ( ) ;
74107
75- // varselector TEST:
108+ // varselector
76109 let varSelector = new VarSelector2 ( this . wrapSelector ( ) , [ 'DataFrame' , 'list' , 'str' ] ) ;
77110 varSelector . setComponentID ( 'predictData' ) ;
78111 varSelector . addClass ( 'vp-state vp-input' ) ;
@@ -85,6 +118,28 @@ define([
85118 varSelector . setValue ( this . state . targetData ) ;
86119 $ ( page ) . find ( '#targetData' ) . replaceWith ( varSelector . toTagString ( ) ) ;
87120
121+ // model
122+ // set model list
123+ let modelOptionTag = new com_String ( ) ;
124+ vpKernel . getModelList ( 'Classification' ) . then ( function ( resultObj ) {
125+ let { result } = resultObj ;
126+ var modelList = JSON . parse ( result ) ;
127+ modelList && modelList . forEach ( model => {
128+ let selectFlag = '' ;
129+ if ( model . varName == that . state . model ) {
130+ selectFlag = 'selected' ;
131+ }
132+ modelOptionTag . appendFormatLine ( '<option value="{0}" data-type="{1}" {2}>{3} ({4})</option>' ,
133+ model . varName , model . varType , selectFlag , model . varName , model . varType ) ;
134+ } ) ;
135+ $ ( page ) . find ( '#model' ) . html ( modelOptionTag . toString ( ) ) ;
136+ $ ( that . wrapSelector ( '#model' ) ) . html ( modelOptionTag . toString ( ) ) ;
137+
138+ if ( ! that . state . model || that . state . model == '' ) {
139+ that . state . model = $ ( that . wrapSelector ( '#model' ) ) . val ( ) ;
140+ }
141+ } ) ;
142+
88143 // load state
89144 let that = this ;
90145 Object . keys ( this . state ) . forEach ( key => {
@@ -114,8 +169,22 @@ define([
114169 }
115170 } ) ;
116171
172+ if ( this . state . modelType == 'clf' ) {
173+ if ( this . state . roc_curve == true || this . state . auc == true ) {
174+ $ ( page ) . find ( '.vp-ev-model' ) . show ( ) ;
175+ } else {
176+ $ ( page ) . find ( '.vp-ev-model' ) . hide ( ) ;
177+ }
178+ } else {
179+ $ ( page ) . find ( '.vp-ev-model' ) . hide ( ) ;
180+ }
181+
117182 return page ;
118183 }
184+
185+ generateImportCode ( ) {
186+ return 'from sklearn import metrics' ;
187+ }
119188
120189 generateCode ( ) {
121190 let codeCells = [ ] ;
@@ -124,6 +193,7 @@ define([
124193 modelType, predictData, targetData,
125194 // classification
126195 confusion_matrix, report, accuracy, precision, recall, f1_score, roc_curve, auc,
196+ model,
127197 // regression
128198 coefficient, intercept, r_squared, mae, mape, rmse, scatter_plot,
129199 // clustering
@@ -173,7 +243,7 @@ define([
173243 if ( roc_curve ) {
174244 code = new com_String ( ) ;
175245 code . appendLine ( "# ROC Curve" ) ;
176- code . appendFormatLine ( "fpr, tpr, thresholds = roc_curve({0}, svc .decision_function({1}} ))" , predictData , targetData ) ;
246+ code . appendFormatLine ( "fpr, tpr, thresholds = metrics. roc_curve({0}, {1} .decision_function({2} ))" , predictData , model , targetData ) ;
177247 code . appendLine ( "plt.plot(fpr, tpr, label='ROC Curve')" ) ;
178248 code . appendLine ( "plt.xlabel('Sensitivity') " ) ;
179249 code . append ( "plt.ylabel('Specificity') " )
@@ -182,8 +252,7 @@ define([
182252 if ( auc ) {
183253 code = new com_String ( ) ;
184254 code . appendLine ( "# AUC" ) ;
185- code . appendFormatLine ( "fpr, tpr, thresholds = roc_curve({0}, svc.decision_function({1}}))" , predictData , targetData ) ;
186- code . append ( "metrics.auc(fpr, tpr)" ) ;
255+ code . appendFormat ( "metrics.roc_auc_score({0}, {1}.decision_function({2}))" , predictData , model , targetData ) ;
187256 codeCells . push ( code . toString ( ) ) ;
188257 }
189258 }
@@ -232,7 +301,7 @@ define([
232301 code . appendLine ( '# Regression plot' ) ;
233302 code . appendFormatLine ( 'plt.scatter({0}, {1})' , targetData , predictData ) ;
234303 code . appendFormatLine ( "plt.xlabel('{0}')" , targetData ) ;
235- code . appendFormatLine ( "plt.ylabel('{1 }')" , predictData ) ;
304+ code . appendFormatLine ( "plt.ylabel('{0 }')" , predictData ) ;
236305 code . append ( 'plt.show()' ) ;
237306 codeCells . push ( code . toString ( ) ) ;
238307 }
0 commit comments