Ho un set di dati con 3 classi con i seguenti elementi:
- Classe 1: 900 elementi
- Classe 2: 15000 elementi
- Classe 3: 800 elementi
Devo prevedere la classe 1 e la classe 3, che segnalano importanti deviazioni dalla norma. La classe 2 è il caso "normale" predefinito che non mi interessa.
Che tipo di funzione di perdita dovrei usare qui? Stavo pensando di usare CrossEntropyLoss, ma dato che c'è uno squilibrio di classe, questo dovrebbe essere ponderato, suppongo? Come funziona in pratica? Ti piace (usando PyTorch)?
summed = 900 + 15000 + 800
weight = torch.tensor([900, 15000, 800]) / summed
crit = nn.CrossEntropyLoss(weight=weight)
O il peso dovrebbe essere invertito? cioè 1 / peso?
È questo l'approccio giusto per cominciare o ci sono altri / metodi migliori che potrei usare?
Grazie