Scikit Binomial Deviance Loss Function


11

Questa è la funzione di perdita di deviazione binomiale di GradientBoosting di scikit,

   def __call__(self, y, pred, sample_weight=None):
        """Compute the deviance (= 2 * negative log-likelihood). """
        # logaddexp(0, v) == log(1.0 + exp(v))
        pred = pred.ravel()
        if sample_weight is None:
            return -2.0 * np.mean((y * pred) - np.logaddexp(0.0, pred))
        else:
            return (-2.0 / sample_weight.sum() *
                    np.sum(sample_weight * ((y * pred) - np.logaddexp(0.0, pred))))

Questa funzione di perdita non è simile tra la classe con 0 e la classe con 1. Qualcuno può spiegare come questo sia considerato OK.

Ad esempio, senza peso campione, la funzione di perdita per la classe 1 è

-2(pred - log(1 + exp(pred))

vs per la classe 0

-2(-log(1+exp(pred))

La trama per questi due non è simile in termini di costi. Qualcuno può aiutarmi a capire.

Risposte:


17

Sono necessarie due osservazioni per comprendere questa implementazione.

Il primo è che nonpred è una probabilità, è una probabilità di registro.

Psklearnpred-2

ylog(p)+(1-y)log(1-p)=log(1-p)+ylog(p1-p)

p=eP1+eP1-p=11+eP1

log(1-p)=log(11+eP)=-log(1+eP)

e

log(p1-p)=log(eP)=P

Complessivamente, la devianza binomiale è uguale

yP-log(1+eP)

Qual è l'equazione che sklearnsta usando.


Grazie. Se sostituisco predcon le probabilità del registro, la funzione di perdita è uniforme per entrambe le classi.
Kumaran,

Questa stessa domanda mi è venuta di recente. Stavo guardando la pagina gradienteboostedmodels.googlecode.com/git/gbm/inst/doc/gbm.pdf dove è elencato il gradiente della devianza. Ma sembra che il gradiente che mostrano sia per il log-lik non per il log-lik negativo. È corretto - sembra corrispondere alla tua spiegazione qui?
B_Miner

1
@B_Miner il link è interrotto
GeneX

Grazie mille @Matthew Drury
Catbuilt
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.