@@ -118,6 +118,7 @@ define([
118118 }
119119
120120 generateCode ( ) {
121+ let codeCells = [ ] ;
121122 let code = new com_String ( ) ;
122123 let {
123124 modelType, predictData, targetData,
@@ -134,40 +135,56 @@ define([
134135 //====================================================================
135136 if ( modelType == 'clf' ) {
136137 if ( confusion_matrix ) {
138+ code = new com_String ( ) ;
137139 code . appendLine ( "# Confusion Matrix" ) ;
138- code . appendFormatLine ( 'pd.crosstab({0}, {1}, margins=True)' , targetData , predictData ) ;
140+ code . appendFormat ( 'pd.crosstab({0}, {1}, margins=True)' , targetData , predictData ) ;
141+ codeCells . push ( code . toString ( ) ) ;
139142 }
140143 if ( report ) {
144+ code = new com_String ( ) ;
141145 code . appendLine ( "# Classification report" ) ;
142- code . appendFormatLine ( 'print(metrics.classification_report({0}, {1}))' , targetData , predictData ) ;
146+ code . appendFormat ( 'print(metrics.classification_report({0}, {1}))' , targetData , predictData ) ;
147+ codeCells . push ( code . toString ( ) ) ;
143148 }
144149 if ( accuracy ) {
150+ code = new com_String ( ) ;
145151 code . appendLine ( "# Accuracy" ) ;
146- code . appendFormatLine ( 'metrics.accuracy_score({0}, {1})' , targetData , predictData ) ;
152+ code . appendFormat ( 'metrics.accuracy_score({0}, {1})' , targetData , predictData ) ;
153+ codeCells . push ( code . toString ( ) ) ;
147154 }
148155 if ( precision ) {
156+ code = new com_String ( ) ;
149157 code . appendLine ( "# Precision" ) ;
150- code . appendFormatLine ( "metrics.precision_score({0}, {1}, average='weighted')" , targetData , predictData ) ;
158+ code . appendFormat ( "metrics.precision_score({0}, {1}, average='weighted')" , targetData , predictData ) ;
159+ codeCells . push ( code . toString ( ) ) ;
151160 }
152161 if ( recall ) {
162+ code = new com_String ( ) ;
153163 code . appendLine ( "# Recall" ) ;
154- code . appendFormatLine ( "metrics.recall_score({0}, {1}, average='weighted')" , targetData , predictData ) ;
164+ code . appendFormat ( "metrics.recall_score({0}, {1}, average='weighted')" , targetData , predictData ) ;
165+ codeCells . push ( code . toString ( ) ) ;
155166 }
156167 if ( f1_score ) {
168+ code = new com_String ( ) ;
157169 code . appendLine ( "# F1-score" ) ;
158- code . appendFormatLine ( "metrics.f1_score({0}, {1}, average='weighted')" , targetData , predictData ) ;
170+ code . appendFormat ( "metrics.f1_score({0}, {1}, average='weighted')" , targetData , predictData ) ;
171+ codeCells . push ( code . toString ( ) ) ;
159172 }
160173 if ( roc_curve ) {
174+ code = new com_String ( ) ;
161175 code . appendLine ( "# ROC Curve" ) ;
162176 code . appendFormatLine ( "fpr, tpr, thresholds = roc_curve({0}, svc.decision_function({1}}))" , predictData , targetData ) ;
163177 code . appendLine ( "plt.plot(fpr, tpr, label='ROC Curve')" ) ;
164178 code . appendLine ( "plt.xlabel('Sensitivity') " ) ;
165- code . appendLine ( "plt.ylabel('Specificity') " )
179+ code . append ( "plt.ylabel('Specificity') " )
180+ codeCells . push ( code . toString ( ) ) ;
166181 }
167182 if ( auc ) {
183+ code = new com_String ( ) ;
168184 code . appendLine ( "# AUC" ) ;
169185 code . appendFormatLine ( "fpr, tpr, thresholds = roc_curve({0}, svc.decision_function({1}}))" , predictData , targetData ) ;
170- code . appendLine ( "metrics.auc(fpr, tpr)" ) ;
186+ code . append ( "metrics.auc(fpr, tpr)" ) ;
187+ codeCells . push ( code . toString ( ) ) ;
171188 }
172189 }
173190
@@ -184,30 +201,40 @@ define([
184201 // code.appendFormatLine('model.intercept_');
185202 // }
186203 if ( r_squared ) {
204+ code = new com_String ( ) ;
187205 code . appendLine ( "# R square" ) ;
188- code . appendFormatLine ( 'metrics.r2_score({0}, {1})' , targetData , predictData ) ;
206+ code . appendFormat ( 'metrics.r2_score({0}, {1})' , targetData , predictData ) ;
207+ codeCells . push ( code . toString ( ) ) ;
189208 }
190209 if ( mae ) {
210+ code = new com_String ( ) ;
191211 code . appendLine ( "# MAE(Mean Absolute Error)" ) ;
192- code . appendFormatLine ( 'metrics.mean_absolute_error({0}, {1})' , targetData , predictData ) ;
212+ code . appendFormat ( 'metrics.mean_absolute_error({0}, {1})' , targetData , predictData ) ;
213+ codeCells . push ( code . toString ( ) ) ;
193214 }
194215 if ( mape ) {
216+ code = new com_String ( ) ;
195217 code . appendLine ( "# MAPE(Mean Absolute Percentage Error)" ) ;
196218 code . appendLine ( 'def MAPE(y_test, y_pred):' ) ;
197219 code . appendLine ( ' return np.mean(np.abs((y_test - pred) / y_test)) * 100' ) ;
198220 code . appendLine ( ) ;
199- code . appendFormatLine ( 'MAPE({0}, {1})' , targetData , predictData ) ;
221+ code . appendFormat ( 'MAPE({0}, {1})' , targetData , predictData ) ;
222+ codeCells . push ( code . toString ( ) ) ;
200223 }
201224 if ( rmse ) {
225+ code = new com_String ( ) ;
202226 code . appendLine ( "# RMSE(Root Mean Squared Error)" ) ;
203- code . appendFormatLine ( 'metrics.mean_squared_error({0}, {1})**0.5' , targetData , predictData ) ;
227+ code . appendFormat ( 'metrics.mean_squared_error({0}, {1})**0.5' , targetData , predictData ) ;
228+ codeCells . push ( code . toString ( ) ) ;
204229 }
205230 if ( scatter_plot ) {
231+ code = new com_String ( ) ;
206232 code . appendLine ( '# Regression plot' ) ;
207233 code . appendFormatLine ( 'plt.scatter({0}, {1})' , targetData , predictData ) ;
208234 code . appendFormatLine ( "plt.xlabel('{0}')" , targetData ) ;
209235 code . appendFormatLine ( "plt.ylabel('{1}')" , predictData ) ;
210- code . appendLine ( 'plt.show()' ) ;
236+ code . append ( 'plt.show()' ) ;
237+ codeCells . push ( code . toString ( ) ) ;
211238 }
212239 }
213240 //====================================================================
@@ -219,20 +246,26 @@ define([
219246 // code.appendFormatLine("print(f'Size of clusters: {np.bincount({0})}')", predictData);
220247 // }
221248 if ( silhouetteScore ) {
249+ code = new com_String ( ) ;
222250 code . appendLine ( "# Silhouette score" ) ;
223- code . appendFormatLine ( "print(f'Silhouette score: {metrics.cluster.silhouette_score({0}, {1})}')" , targetData , predictData ) ;
251+ code . appendFormat ( "print(f'Silhouette score: {metrics.cluster.silhouette_score({0}, {1})}')" , targetData , predictData ) ;
252+ codeCells . push ( code . toString ( ) ) ;
224253 }
225254 if ( ari ) {
255+ code = new com_String ( ) ;
226256 code . appendLine ( "# ARI" ) ;
227- code . appendFormatLine ( "print(f'ARI: {metrics.cluster.adjusted_rand_score({0}, {1})}')" , targetData , predictData ) ;
257+ code . appendFormat ( "print(f'ARI: {metrics.cluster.adjusted_rand_score({0}, {1})}')" , targetData , predictData ) ;
258+ codeCells . push ( code . toString ( ) ) ;
228259 }
229260 if ( nm ) {
261+ code = new com_String ( ) ;
230262 code . appendLine ( "# NM" ) ;
231- code . appendFormatLine ( "print(f'NM: {metrics.cluster.normalized_mutual_info_score({0}, {1})}')" , targetData , predictData ) ;
263+ code . appendFormat ( "print(f'NM: {metrics.cluster.normalized_mutual_info_score({0}, {1})}')" , targetData , predictData ) ;
264+ codeCells . push ( code . toString ( ) ) ;
232265 }
233266 }
234- // FIXME: as seperated cells
235- return code . toString ( ) ;
267+ // return as seperated cells
268+ return codeCells ;
236269 }
237270
238271 }
0 commit comments