@@ -41,8 +41,12 @@ define([
4141 confusion_matrix : true , report : true ,
4242 accuracy : false , precision : false , recall : false , f1_score : false ,
4343 roc_curve : false , auc : false ,
44+ model : '' ,
4445 // clustering
45- silhouetteScore : true , ari : false , nm : false ,
46+ clusteredIndex : 'clusters' ,
47+ silhouetteScore : true , ari : false , nmi : false ,
48+ featureData2 : 'X' ,
49+ targetData2 : 'y' ,
4650 ...this . state
4751 }
4852 }
@@ -62,14 +66,23 @@ define([
6266 let modelType = $ ( this ) . val ( ) ;
6367 that . state . modelType = modelType ;
6468
69+ $ ( that . wrapSelector ( '.vp-upper-box' ) ) . hide ( ) ;
70+ $ ( that . wrapSelector ( '.vp-upper-box.' + modelType ) ) . show ( ) ;
71+
6572 $ ( that . wrapSelector ( '.vp-eval-box' ) ) . hide ( ) ;
6673 $ ( that . wrapSelector ( '.vp-eval-' + modelType ) ) . show ( ) ;
6774
68- if ( modelType == 'clf' ) {
75+ if ( modelType == 'rgs' ) {
76+ // Regression
77+
78+ } else if ( modelType == 'clf' ) {
6979 // Classification - model selection
7080 if ( that . checkToShowModel ( ) == true ) {
7181 $ ( that . wrapSelector ( '.vp-ev-model' ) ) . show ( ) ;
7282 }
83+ } else {
84+ // Clustering
85+
7386 }
7487 } ) ;
7588
@@ -118,6 +131,25 @@ define([
118131 varSelector . setValue ( this . state . targetData ) ;
119132 $ ( page ) . find ( '#targetData' ) . replaceWith ( varSelector . toTagString ( ) ) ;
120133
134+ // Clustering - data selection
135+ varSelector = new VarSelector2 ( this . wrapSelector ( ) , [ 'DataFrame' , 'list' , 'str' ] ) ;
136+ varSelector . setComponentID ( 'clusteredIndex' ) ;
137+ varSelector . addClass ( 'vp-state vp-input' ) ;
138+ varSelector . setValue ( this . state . clusteredIndex ) ;
139+ $ ( page ) . find ( '#clusteredIndex' ) . replaceWith ( varSelector . toTagString ( ) ) ;
140+
141+ varSelector = new VarSelector2 ( this . wrapSelector ( ) , [ 'DataFrame' , 'list' , 'str' ] ) ;
142+ varSelector . setComponentID ( 'featureData2' ) ;
143+ varSelector . addClass ( 'vp-state vp-input' ) ;
144+ varSelector . setValue ( this . state . featureData2 ) ;
145+ $ ( page ) . find ( '#featureData2' ) . replaceWith ( varSelector . toTagString ( ) ) ;
146+
147+ varSelector = new VarSelector2 ( this . wrapSelector ( ) , [ 'DataFrame' , 'list' , 'str' ] ) ;
148+ varSelector . setComponentID ( 'targetData2' ) ;
149+ varSelector . addClass ( 'vp-state vp-input' ) ;
150+ varSelector . setValue ( this . state . targetData2 ) ;
151+ $ ( page ) . find ( '#targetData2' ) . replaceWith ( varSelector . toTagString ( ) ) ;
152+
121153 // model
122154 // set model list
123155 let modelOptionTag = new com_String ( ) ;
@@ -169,7 +201,12 @@ define([
169201 }
170202 } ) ;
171203
172- if ( this . state . modelType == 'clf' ) {
204+ $ ( page ) . find ( '.vp-upper-box' ) . hide ( ) ;
205+ $ ( page ) . find ( '.vp-upper-box.' + this . state . modelType ) . show ( ) ;
206+
207+ if ( this . state . modelType == 'rgs' ) {
208+
209+ } else if ( this . state . modelType == 'clf' ) {
173210 if ( this . state . roc_curve == true || this . state . auc == true ) {
174211 $ ( page ) . find ( '.vp-ev-model' ) . show ( ) ;
175212 } else {
@@ -197,7 +234,8 @@ define([
197234 // regression
198235 coefficient, intercept, r_squared, mae, mape, rmse, scatter_plot,
199236 // clustering
200- sizeOfClusters, silhouetteScore, ari, nm
237+ sizeOfClusters, silhouetteScore, ari, nmi,
238+ clusteredIndex, featureData2, targetData2
201239 } = this . state ;
202240
203241 //====================================================================
@@ -317,19 +355,19 @@ define([
317355 if ( silhouetteScore ) {
318356 code = new com_String ( ) ;
319357 code . appendLine ( "# Silhouette score" ) ;
320- code . appendFormat ( "print(f'Silhouette score: {metrics.cluster.silhouette_score({0}, {1})}')" , targetData , predictData ) ;
358+ code . appendFormat ( "print(f'Silhouette score: {metrics.cluster.silhouette_score({0}, {1})}')" , featureData2 , clusteredIndex ) ;
321359 codeCells . push ( code . toString ( ) ) ;
322360 }
323361 if ( ari ) {
324362 code = new com_String ( ) ;
325- code . appendLine ( "# ARI" ) ;
326- code . appendFormat ( "print(f'ARI: {metrics.cluster.adjusted_rand_score({0}, {1})}')" , targetData , predictData ) ;
363+ code . appendLine ( "# ARI(Adjusted Rand score) " ) ;
364+ code . appendFormat ( "print(f'ARI: {metrics.cluster.adjusted_rand_score({0}, {1})}')" , targetData2 , clusteredIndex ) ;
327365 codeCells . push ( code . toString ( ) ) ;
328366 }
329- if ( nm ) {
367+ if ( nmi ) {
330368 code = new com_String ( ) ;
331- code . appendLine ( "# NM " ) ;
332- code . appendFormat ( "print(f'NM: {metrics.cluster.normalized_mutual_info_score({0}, {1})}')" , targetData , predictData ) ;
369+ code . appendLine ( "# NMI(Normalized Mutual Info Score) " ) ;
370+ code . appendFormat ( "print(f'NM: {metrics.cluster.normalized_mutual_info_score({0}, {1})}')" , targetData2 , clusteredIndex ) ;
333371 codeCells . push ( code . toString ( ) ) ;
334372 }
335373 }
0 commit comments