Classificatore con precisione regolabile vs richiamo


11

Sto lavorando a un problema di classificazione binaria in cui è molto più importante non avere falsi positivi; molti falsi negativi vanno bene. Ho usato un sacco di classificatori in sklearn per esempio, ma penso che nessuno di loro abbia la capacità di regolare esplicitamente il compromesso del richiamo di precisione (producono risultati piuttosto buoni ma non regolabili).

Quali classificatori hanno precisione / richiamo regolabili? Esiste un modo per influenzare il compromesso di precisione / richiamo sui classificatori standard, ad esempio Random Forest o AdaBoost?

Risposte:


12

Quasi tutti i classificatori di scikit-learn possono fornire valori di decisione (tramite decision_functiono predict_proba).

Sulla base dei valori decisionali è semplice calcolare curve di richiamo di precisione e / o ROC. scikit-learn fornisce queste funzioni nel suo sottomodulo di metriche .

Un esempio minimo, supponendo che tu abbia datae labelscon contenuti adeguati:

import sklearn.svm
import sklearn.metrics
from matplotlib import pyplot as plt

clf = sklearn.svm.LinearSVC().fit(data, labels)
decision_values = clf.decision_function(data)

precision, recall, thresholds = sklearn.metrics.precision_recall_curve(labels, decision_values)

plt.plot(recall, precision)
plt.show()

Perfetto, grazie! Non sono sicuro di come mi sia perso :)
Alex I

Sembra precision_recall_curvecalcolare l'intera F1. Come calcolare solo quelli negativi?
Mithril,

6

Ho appena risolto questo problema da solo prima di imbattermi in questa Q, quindi ho deciso di condividere la mia soluzione.

Utilizza lo stesso approccio proposto da Marc Claesen, ma risponde alla domanda generale su come regolare il classificatore per spostarsi più in alto sul trading degli assi di precisione fuori dal richiamo.

X_test è i dati e y_test sono le etichette vere. Il classificatore dovrebbe essere già installato.

y_score = clf.decision_function(X_test)

prcsn,rcl,thrshld=precision_recall_curve(y_test,y_score)

min_prcsn=0.25 # here is your precision lower bound e.g. 25%
min_thrshld=min([thrshld[i] for i in range(len(thrshld)) if prcsn[i]>min_prcsn])

E questo è il modo in cui useresti la soglia minima appena imparata per regolare la tua previsione (che altrimenti otterresti semplicemente chiamando predict (X_test))

y_pred_adjusted=[1 if y_s>min_thrshld else 0 for y_s in y_score]

Sarebbe bello sentire il tuo feedback su questa ricetta di regolazione.


Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.