Architetture della CNN per regressione?


32

Ho lavorato su un problema di regressione in cui l'input è un'immagine e l'etichetta ha un valore continuo tra 80 e 350. Le immagini sono di alcuni prodotti chimici dopo una reazione. Il colore che risulta indica la concentrazione di un'altra sostanza chimica rimasta, ed è quello che il modello deve produrre: la concentrazione di quella sostanza chimica. Le immagini possono essere ruotate, capovolte, specchiate e l'output previsto dovrebbe essere sempre lo stesso. Questo tipo di analisi viene eseguita in veri e propri laboratori (macchine molto specializzate generano la concentrazione delle sostanze chimiche usando l'analisi del colore proprio come sto addestrando questo modello a fare).

Finora ho sperimentato solo modelli basati approssimativamente su VGG (sequenze multiple di blocchi conv-conv-conv-pool). Prima di sperimentare architetture più recenti (Inception, ResNets, ecc.), Ho pensato di ricercare se ci sono altre architetture più comunemente usate per la regressione usando le immagini.

Il set di dati è simile al seguente:

inserisci qui la descrizione dell'immagine

Il set di dati contiene circa 5.000 campioni 250x250, che ho ridimensionato a 64x64, quindi l'addestramento è più semplice. Una volta trovata un'architettura promettente, sperimenterò immagini con risoluzione maggiore.

Finora, i miei migliori modelli hanno un errore quadratico medio su entrambi i set di addestramento e di validazione di circa 0,3, che è tutt'altro che accettabile nel mio caso d'uso.

Finora il mio miglior modello è simile al seguente:

// pseudo code
x = conv2d(x, filters=32, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=32, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=32, kernel=[3,3])->batch_norm()->relu()
x = maxpool(x, size=[2,2], stride=[2,2])

x = conv2d(x, filters=64, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=64, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=64, kernel=[3,3])->batch_norm()->relu()
x = maxpool(x, size=[2,2], stride=[2,2])

x = conv2d(x, filters=128, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=128, kernel=[3,3])->batch_norm()->relu()
x = conv2d(x, filters=128, kernel=[3,3])->batch_norm()->relu()
x = maxpool(x, size=[2,2], stride=[2,2])

x = dropout()->conv2d(x, filters=128, kernel=[1, 1])->batch_norm()->relu()
x = dropout()->conv2d(x, filters=32, kernel=[1, 1])->batch_norm()->relu()

y = dense(x, units=1)

// loss = mean_squared_error(y, labels)

Domanda

Qual è un'architettura appropriata per l'output di regressione da un input di immagine?

modificare

Ho riformulato la mia spiegazione e rimosso le menzioni di accuratezza.

Modifica 2

Ho ristrutturato la mia domanda, quindi spero sia chiaro cosa sto cercando


4
La precisione non è una misura che può essere applicata direttamente ai problemi di regressione. Cosa intendi quando dici che la tua precisione è del 30%? La precisione si applica solo alle attività di classificazione, non alla regressione.
Nuclear Wang,

1
Cosa intendi con "prevede correttamente il 30% delle volte" ? Stai davvero facendo regressione?
Firebug

1
Perché si chiama questa regressione del problema? Non stai cercando di classificare in etichette? sono le etichette cardinali?
Aksakal,

2
Non voglio la stessa identica cosa di vgg. Sto facendo qualcosa di simile a vgg, ovvero una serie di conv seguite da un pool massimo, seguito da una connessione completa. Sembra un approccio generico per lavorare con le immagini. Ma poi di nuovo, questo è il punto centrale della mia domanda originale. Sembra che tutti questi commenti, anche se perspicaci per me, mancano completamente il punto di quello che sto chiedendo in primo luogo.
rodrigo-silveira,

1
y[80,350]θy

Risposte:


42

Prima di tutto un suggerimento generale: fai una ricerca in letteratura prima di iniziare a fare esperimenti su un argomento che non conosci. Ti risparmierai molto tempo.

In questo caso, guardando documenti esistenti potresti averlo notato

  1. Le CNN sono state usate più volte per la regressione: questo è un classico ma è vecchio (sì, 3 anni è vecchio in DL). Un documento più moderno non avrebbe usato AlexNet per questo compito. Questo è più recente, ma è per un problema molto più complicato (rotazione 3D), e comunque non ne ho familiarità.
  2. La regressione con CNN non è un problema banale. Guardando di nuovo il primo documento, vedrai che hanno un problema in cui possono praticamente generare dati infiniti. Il loro obiettivo è prevedere l'angolo di rotazione necessario per correggere le immagini 2D. Ciò significa che posso praticamente prendere il mio set di allenamento e aumentarlo ruotando ogni immagine di angoli arbitrari, e otterrò un set di allenamento valido e più grande. Quindi il problema sembra relativamente semplice, per quanto riguarda i problemi di Deep Learning. A proposito, nota gli altri trucchi di aumento dei dati che usano:

    Utilizziamo traduzioni (fino al 5% della larghezza dell'immagine), regolazione della luminosità nell'intervallo [−0,2, 0,2], regolazione gamma con γ ∈ [−0,5, 0,1] e rumore pixel gaussiano con una deviazione standard nell'intervallo [0 , 0,02].

    k

    yxα=atan2(y,x)>11%dell'errore massimo possibile. Hanno fatto leggermente meglio usando due reti in serie: la prima avrebbe eseguito la classificazione (prevedere se l'angolo sarebbe tra o ), quindi l'immagine, ruotata della quantità prevista dalla prima rete, verrebbe inviata ad un'altra rete neurale (per regressione, questa volta), che prevederebbe la rotazione aggiuntiva finale in l' intervallo .[180°,90°],[90°,0°],[0°,90°][90°,180°][45°,45°]

    Su un problema molto più semplice (MNIST ruotato), puoi ottenere qualcosa di meglio , ma comunque non vai al di sotto di un errore RMSE che è il dell'errore massimo possibile.2.6%

Quindi, cosa possiamo imparare da questo? Prima di tutto, che 5000 immagini sono un piccolo set di dati per la tua attività. Il primo documento utilizzava una rete che era stata pre-stampata su immagini simili a quella per cui volevano apprendere l'attività di regressione: non solo devi imparare un compito diverso da quello per cui l'architettura è stata progettata (classificazione), ma il tuo set di formazione non lo fa non assomigli affatto ai set di addestramento su cui queste reti sono solitamente addestrate (CIFAR-10/100 o ImageNet). Quindi probabilmente non trarrai alcun vantaggio dall'apprendimento del trasferimento. L'esempio MATLAB aveva 5000 immagini, ma erano in bianco e nero e semanticamente tutte molto simili (beh, anche questo potrebbe essere il tuo caso).

Quindi, quanto è realistico fare meglio di 0,3? Dobbiamo prima di tutto capire cosa intendi per 0,3 perdite medie. Vuoi dire che l'errore RMSE è 0,3,

1Ni=1N(h(xi)yi)2

dove è la dimensione del tuo set di allenamento (quindi, ), è l'output della tua CNN per l'immagine e è la concentrazione corrispondente della sostanza chimica? Dato che , supponendo quindi di tagliare le previsioni della CNN tra 80 e 350 (o di usare semplicemente un logit per adattarle a quell'intervallo), si ottiene un errore inferiore allo . Seriamente, cosa ti aspetti? non mi sembra affatto un grosso errore.NN<5000h(xi)xiyiyi[80,350]0.12%

Inoltre, prova a calcolare il numero di parametri nella tua rete: sono di fretta e potrei fare errori stupidi, quindi ricontrolla i miei calcoli con qualche summaryfunzione da qualunque framework tu stia usando. Tuttavia, approssimativamente direi che hai

9×(3×32+2×32×32+32×64+2×64×64+64×128+2×128×128)+128×128+128×32+32×32×32=533344

(nota che ho saltato i parametri dei livelli della norma batch, ma sono solo 4 parametri per il livello, quindi non fanno differenza). Hai mezzo milione di parametri e 5000 esempi ... cosa ti aspetteresti? Certo, il numero di parametri non è un buon indicatore della capacità di una rete neurale (è un modello non identificabile), ma comunque ... Non penso che tu possa fare molto meglio di questo, ma puoi provare un poche cose:

  • normalizzare tutti gli input (ad esempio, ridimensionare le intensità RGB di ciascun pixel tra -1 e 1 o utilizzare la standardizzazione) e tutti gli output. Ciò sarà particolarmente utile in caso di problemi di convergenza.
  • vai in scala di grigi: questo ridurrebbe i tuoi canali di input da 3 a 1. Tutte le tue immagini sembrano (ai miei occhi non addestrati) di colori relativamente simili. Sei sicuro che sia il colore necessario per prevedere e non l'esistenza di aree più scure o più luminose? Forse sei sicuro (non sono un esperto): in questo caso salta questo suggerimento.y
  • Dati aumento: dal momento che lei ha detto che flipping, rotazione di un angolo arbitrario o il mirroring le immagini dovrebbero comportare la stessa uscita, è possibile aumentare la dimensione dei dati impostato un sacco . Si noti che con un set di dati più grande l'errore sul set di allenamento aumenterà: quello che stiamo cercando qui è un divario minore tra la perdita del set di allenamento e la perdita del set di test. Inoltre, se la perdita del set di allenamento aumenta molto, questa potrebbe essere una buona notizia: può significare che puoi allenare una rete più profonda su questo set di allenamento più grande senza il rischio di un overfitting. Prova ad aggiungere più livelli e vedi se ora ottieni un set di allenamento più piccolo e verifica la perdita del set. Infine, potresti provare anche gli altri trucchi di aumento dei dati che ho citato sopra, se hanno senso nel contesto della tua applicazione.
  • usa il trucco di classificazione, quindi di regressione: una prima rete determina solo se dovrebbe trovarsi in uno, diciamo, di 10 bin, come , ecc. Una seconda rete calcola quindi una correzione : centrare e normalizzare può aiutare anche qui. Non posso dire senza provare.y[80,97],[97,124][0,27]
  • prova a utilizzare un'architettura moderna (Inception o ResNet) anziché un'architettura vintage. ResNet ha in realtà meno parametri di VGG-net. Ovviamente, vuoi usare le piccole ResNet qui - Non penso che ResNet-101 possa aiutare su un set di dati di 5000 immagini. Puoi aumentare molto il set di dati, però ....
  • Dato che l'output è invariante alla rotazione, un'altra grande idea sarebbe quella di utilizzare CNN equivarianti di gruppo , il cui output (quando utilizzato come classificatori) è invariante a rotazioni discrete o CNN orientabiliil cui output è invariante alle rotazioni continue. La proprietà invarianza ti consentirebbe di ottenere buoni risultati con un aumento molto minore dei dati, o idealmente nessuno (per quanto riguarda le rotazioni: ovviamente hai ancora bisogno degli altri tipi di da). Le CNN equivalenti di gruppo sono più mature delle CNN orientabili dal punto di vista dell'implementazione, quindi proverei prima le CNN di gruppo. Puoi provare la classificazione-poi-regressione, usando G-CNN per la parte di classificazione, oppure puoi sperimentare l'approccio di regressione pura. Ricorda di cambiare il livello superiore di conseguenza.
  • sperimentare le dimensioni del batch (sì, sì, so che l'hyperparametro-hacking non è bello, ma questo è il migliore che potrei venire in un lasso di tempo limitato e gratuitamente :-)
  • infine, ci sono architetture che sono state appositamente sviluppate per fare previsioni accurate con piccoli set di dati. La maggior parte di loro utilizzava convoluzioni dilatate : un famoso esempio è la fitta rete neurale convoluzionale . L'implementazione non è banale, però.

3
Grazie per la risposta dettagliata Avevo già fatto un significativo aumento dei dati. Ho provato un paio di varianti del modello di inizio (dove una variazione significa che il numero di filtri è ridimensionato equamente su tutto il modello). Ho visto incredibili miglioramenti. Ho ancora molta strada da fare. Proverò alcuni dei tuoi suggerimenti. Grazie ancora.
rodrigo-silveira,

@ rodrigo-silveira, prego, fammi sapere come va. Forse potremo parlare in chat una volta ottenuti i risultati.
DeltaIV,

1
Ottima risposta, merita di più ^
Gilly,

1
Molto ben composto!
Karthik Thiagarajan,

1
Se potessi, ti darei 10k punti per questo. Risposta incredibile
Boppity Bop,
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.