@@ -546,7 +546,7 @@ define([
546546 name : 'roc_curve' ,
547547 label : 'ROC Curve' ,
548548 import : 'from sklearn import metrics' ,
549- code : "fpr, tpr, thresholds = metrics.roc_curve(${roc_targetData}, ${model}.decision_function (${roc_featureData}))\n\
549+ code : "fpr, tpr, thresholds = metrics.roc_curve(${roc_targetData}, ${model}.predict_proba (${roc_featureData}))\n\
550550plt.plot(fpr, tpr, label='ROC Curve')\n\
551551plt.xlabel('Sensitivity')\n\
552552plt.ylabel('Specificity')\n\
@@ -561,7 +561,7 @@ plt.show()",
561561 name : 'auc' ,
562562 label : 'AUC' ,
563563 import : 'from sklearn import metrics' ,
564- code : 'metrics.roc_auc_score(${auc_targetData}, ${model}.decision_function (${auc_featureData}))' ,
564+ code : 'metrics.roc_auc_score(${auc_targetData}, ${model}.predict_proba (${auc_featureData}))' ,
565565 description : '' ,
566566 options : [
567567 { name : 'auc_targetData' , label : 'Target Data' , component : [ 'var_select' ] , var_type : [ 'DataFrame' , 'Series' , 'ndarray' , 'list' , 'dict' ] , value : 'y_test' } ,
@@ -570,6 +570,28 @@ plt.show()",
570570 } ,
571571 'permutation_importance' : defaultInfos [ 'permutation_importance' ]
572572 }
573+
574+ // use decision_function on ROC, AUC
575+ let decisionFunctionTypes = [
576+ 'LogisticRegression' , 'SVC' , 'GradientBoostingClassifier'
577+ ] ;
578+ if ( decisionFunctionTypes . includes ( modelType ) ) {
579+ infos = {
580+ ...infos ,
581+ 'roc_curve' : {
582+ ...infos [ 'roc_curve' ] ,
583+ code : "fpr, tpr, thresholds = metrics.roc_curve(${roc_targetData}, ${model}.decision_function(${roc_featureData}))\n\
584+ plt.plot(fpr, tpr, label='ROC Curve')\n\
585+ plt.xlabel('Sensitivity')\n\
586+ plt.ylabel('Specificity')\n\
587+ plt.show()"
588+ } ,
589+ 'auc' : {
590+ ...infos [ 'auc' ] ,
591+ code : 'metrics.roc_auc_score(${auc_targetData}, ${model}.decision_function(${auc_featureData}))' ,
592+ }
593+ }
594+ }
573595 break ;
574596 case 'Auto ML' :
575597 infos = {
0 commit comments