|
13 | 13 | // [CLASS] Evaluation |
14 | 14 | //============================================================================ |
15 | 15 | define([ |
| 16 | + 'text!vp_base/html/m_ml/evaluation.html!strip', |
16 | 17 | 'vp_base/js/com/com_util', |
17 | 18 | 'vp_base/js/com/com_Const', |
18 | 19 | 'vp_base/js/com/com_String', |
19 | 20 | 'vp_base/js/com/component/PopupComponent' |
20 | | -], function(com_util, com_Const, com_String, PopupComponent) { |
| 21 | +], function(evalHTML, com_util, com_Const, com_String, PopupComponent) { |
21 | 22 |
|
22 | 23 | /** |
23 | 24 | * Evaluation |
24 | 25 | */ |
25 | 26 | class Evaluation extends PopupComponent { |
26 | 27 | _init() { |
27 | 28 | super._init(); |
28 | | - /** Write codes executed before rendering */ |
| 29 | + this.config.dataview = false; |
| 30 | + |
| 31 | + this.state = { |
| 32 | + modelType: 'clf', |
| 33 | + predictData: 'pred', |
| 34 | + targetData: 'y_test', |
| 35 | + // classification |
| 36 | + confusion_matrix: true, report: true, |
| 37 | + accuracy: false, precision: false, recall: false, f1_score: false, |
| 38 | + // regression |
| 39 | + coefficient: false, intercept: false, r_squared: true, |
| 40 | + mae: false, mape: false, rmse: true, scatter_plot: false, |
| 41 | + ...this.state |
| 42 | + } |
29 | 43 | } |
30 | 44 |
|
31 | 45 | _bindEvent() { |
32 | 46 | super._bindEvent(); |
33 | 47 | /** Implement binding events */ |
34 | 48 | var that = this; |
35 | | - this.$target.on('click', function(evt) { |
36 | | - var target = evt.target; |
37 | | - if ($(that.wrapSelector()).find(target).length > 0) { |
38 | | - // Sample : getDataList from Kernel |
39 | | - vpKernel.getDataList().then(function(resultObj) { |
40 | | - vpLog.display(VP_LOG_TYPE.DEVELOP, resultObj); |
41 | | - }).catch(function(err) { |
42 | | - vpLog.display(VP_LOG_TYPE.DEVELOP, err); |
43 | | - }); |
44 | | - } |
| 49 | + |
| 50 | + // import library |
| 51 | + $(this.wrapSelector('#vp_importLibrary')).on('click', function() { |
| 52 | + com_interface.insertCell('code', 'from sklearn import metrics'); |
45 | 53 | }); |
| 54 | + |
| 55 | + // model type change |
| 56 | + $(this.wrapSelector('#modelType')).on('change', function() { |
| 57 | + let modelType = $(this).val(); |
| 58 | + that.state.modelType = modelType; |
| 59 | + |
| 60 | + $(page).find('.vp-eval-box').hide(); |
| 61 | + $(page).find('.vp-eval-'+modelType).show(); |
| 62 | + }) |
46 | 63 | } |
47 | 64 |
|
48 | 65 | templateForBody() { |
49 | | - /** Implement generating template */ |
50 | | - return 'This is sample.'; |
| 66 | + let page = $(evalHTML); |
| 67 | + |
| 68 | + $(page).find('.vp-eval-box').hide(); |
| 69 | + $(page).find('.vp-eval-'+this.state.modelType).show(); |
| 70 | + |
| 71 | + return page; |
51 | 72 | } |
52 | 73 |
|
53 | 74 | generateCode() { |
54 | | - return "print('sample code')"; |
| 75 | + let code = new com_String(); |
| 76 | + let { |
| 77 | + modelType, predictData, targetData, |
| 78 | + // classification |
| 79 | + confusion_matrix, report, accuracy, precision, recall, f1_score, |
| 80 | + // regression |
| 81 | + coefficient, intercept, r_squared, mae, mape, rmse, scatter_plot |
| 82 | + } = this.state; |
| 83 | + |
| 84 | + //==================================================================== |
| 85 | + // Classfication |
| 86 | + //==================================================================== |
| 87 | + if (modelType == 'clf') { |
| 88 | + if (confusion_matrix) { |
| 89 | + code.appendLine("# Confusion Matrix"); |
| 90 | + code.appendFormatLine('pd.crosstab({0}, {1}, margins=True)', targetData, predictData); |
| 91 | + } |
| 92 | + if (report) { |
| 93 | + code.appendLine("# Classification report"); |
| 94 | + code.appendFormatLine('print(metrics.classification_report({0}, {1}))', targetData, predictData); |
| 95 | + } |
| 96 | + if (accuracy) { |
| 97 | + code.appendLine("# Accuracy"); |
| 98 | + code.appendFormatLine('metrics.accuracy_score({0}, {1})', targetData, predictData); |
| 99 | + } |
| 100 | + if (precision) { |
| 101 | + code.appendLine("# Precision"); |
| 102 | + code.appendFormatLine("metrics.precision_score({0}, {1}, average='weighted')", targetData, predictData); |
| 103 | + } |
| 104 | + if (recall) { |
| 105 | + code.appendLine("# Recall"); |
| 106 | + code.appendFormatLine("metrics.recall_score({0}, {1}, average='weighted')", targetData, predictData); |
| 107 | + } |
| 108 | + if (f1_score) { |
| 109 | + code.appendLine("# F1-score"); |
| 110 | + code.appendFormatLine("metrics.f1_score({0}, {1}, average='weighted')", targetData, predictData); |
| 111 | + } |
| 112 | + } |
| 113 | + |
| 114 | + //==================================================================== |
| 115 | + // Regression |
| 116 | + //==================================================================== |
| 117 | + if (modelType == 'rgs') { |
| 118 | + if (coefficient) { |
| 119 | + code.appendLine("# Coefficient (scikit-learn only)"); |
| 120 | + code.appendFormatLine('model.coef_'); |
| 121 | + } |
| 122 | + if (intercept) { |
| 123 | + code.appendLine("# Intercept (scikit-learn only)"); |
| 124 | + code.appendFormatLine('model.intercept_'); |
| 125 | + } |
| 126 | + if (r_squared) { |
| 127 | + code.appendLine("# R square"); |
| 128 | + code.appendFormatLine('metrics.r2_score({0}, {1})', targetData, predictData); |
| 129 | + } |
| 130 | + if (mae) { |
| 131 | + code.appendLine("# MAE(Mean Absolute Error)"); |
| 132 | + code.appendFormatLine('metrics.mean_absolute_error({0}, {1})', targetData, predictData); |
| 133 | + } |
| 134 | + if (mape) { |
| 135 | + code.appendLine("# MAPE(Mean Absolute Percentage Error)"); |
| 136 | + code.appendLine('def MAPE(y_test, y_pred):'); |
| 137 | + code.appendLine(' return np.mean(np.abs((y_test - pred) / y_test)) * 100'); |
| 138 | + code.appendLine(); |
| 139 | + code.appendFormatLine('MAPE({0}, {1})', targetData, predictData); |
| 140 | + } |
| 141 | + if (rmse) { |
| 142 | + code.appendLine("# RMSE(Root Mean Squared Error)"); |
| 143 | + code.appendFormatLine('metrics.mean_squared_error({0}, {1})**0.5', targetData, predictData); |
| 144 | + } |
| 145 | + if (scatter_plot) { |
| 146 | + code.appendLine('# Regression plot'); |
| 147 | + code.appendFormatLine('plt.scatter({0}, {1})', targetData, predictData); |
| 148 | + code.appendFormatLine("plt.xlabel('{0}')", targetData); |
| 149 | + code.appendFormatLine("plt.ylabel('{1}')", predictData); |
| 150 | + code.appendLine('plt.show()'); |
| 151 | + } |
| 152 | + } |
| 153 | + // FIXME: as seperated cells |
| 154 | + return code.toString(); |
55 | 155 | } |
56 | 156 |
|
57 | 157 | } |
|
0 commit comments