Quale funzione di perdita usare per le classi sbilanciate (usando PyTorch)?


17

Ho un set di dati con 3 classi con i seguenti elementi:

  • Classe 1: 900 elementi
  • Classe 2: 15000 elementi
  • Classe 3: 800 elementi

Devo prevedere la classe 1 e la classe 3, che segnalano importanti deviazioni dalla norma. La classe 2 è il caso "normale" predefinito che non mi interessa.

Che tipo di funzione di perdita dovrei usare qui? Stavo pensando di usare CrossEntropyLoss, ma dato che c'è uno squilibrio di classe, questo dovrebbe essere ponderato, suppongo? Come funziona in pratica? Ti piace (usando PyTorch)?

summed = 900 + 15000 + 800
weight = torch.tensor([900, 15000, 800]) / summed
crit = nn.CrossEntropyLoss(weight=weight)

O il peso dovrebbe essere invertito? cioè 1 / peso?

È questo l'approccio giusto per cominciare o ci sono altri / metodi migliori che potrei usare?

Grazie

Risposte:


13

Che tipo di funzione di perdita dovrei usare qui?

L'entropia incrociata è la funzione di perdita per compiti di classificazione, sia bilanciati che sbilanciati. È la prima scelta quando non viene ancora creata alcuna preferenza dalla conoscenza del dominio.

Questo dovrebbe essere ponderato suppongo? Come funziona in pratica?

Sì. Peso della classec è la dimensione della classe più grande divisa per la dimensione della classe c.

Ad esempio, se la classe 1 ha 900, la classe 2 ha 15000 e la classe 3 ha 800 campioni, i loro pesi sarebbero rispettivamente 16,67, 1,0 e 18,75.

Puoi anche usare la classe più piccola come nominatore, che fornisce rispettivamente 0,889, 0,053 e 1,0. Questo è solo un ridimensionamento, i pesi relativi sono gli stessi.

È questo l'approccio giusto per cominciare o ci sono altri / metodi migliori che potrei usare?

Sì, questo è l'approccio giusto.

MODIFICA :

Grazie a @Muppet, possiamo anche utilizzare un sovracampionamento di classe, che equivale a utilizzare i pesi di classe . Ciò viene realizzato da WeightedRandomSamplerPyTorch, utilizzando gli stessi pesi sopra menzionati.


2
Volevo solo aggiungere che anche l'uso di WeightedRandomSampler di PyTorch mi ha aiutato, nel caso qualcuno lo stia guardando.
Muppet,

0

Quando dici: puoi anche usare la classe più piccola come nominatore, che dà rispettivamente 0,889, 0,053 e 1,0. Questo è solo un ridimensionamento, i pesi relativi sono gli stessi.

Ma questa soluzione è in contraddizione con la prima che hai dato, come funziona?

Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.