Il modo migliore per salvare un modello addestrato in PyTorch?


195

Stavo cercando modi alternativi per salvare un modello addestrato in PyTorch. Finora ho trovato due alternative.

  1. torch.save () per salvare un modello e torch.load () per caricare un modello.
  2. model.state_dict () per salvare un modello addestrato e model.load_state_dict () per caricare il modello salvato.

Mi sono imbattuto in questa discussione in cui l'approccio 2 è raccomandato rispetto all'approccio 1.

La mia domanda è: perché è preferito il secondo approccio? È solo perché i moduli torch.nn hanno queste due funzioni e siamo incoraggiati a usarle?


2
Penso che sia perché torch.save () salva anche tutte le variabili intermedie, come le uscite intermedie per l'uso della propagazione posteriore. Ma devi solo salvare i parametri del modello, come peso / distorsione, ecc. A volte il primo può essere molto più grande del secondo.
Dawei Yang,

2
Ho testato torch.save(model, f)e torch.save(model.state_dict(), f). I file salvati hanno le stesse dimensioni. Ora sono confuso. Inoltre, ho trovato l'utilizzo di pickle per salvare model.state_dict () estremamente lento. Penso che il modo migliore sia quello di utilizzare la torch.save(model.state_dict(), f)gestione del modello e la torcia gestisce il caricamento dei pesi del modello, eliminando così possibili problemi. Riferimento: discuss.pytorch.org/t/saving-torch-models/838/4
Dawei Yang

Sembra che PyTorch abbia affrontato questo argomento in modo un po 'più esplicito nella sezione tutorial: ci sono molte buone informazioni lì che non sono elencate nelle risposte qui, incluso il salvataggio di più di un modello alla volta e modelli di avvio a caldo.
whlteXbread

cosa c'è di sbagliato nell'usare pickle?
Charlie Parker,

1
@CharlieParker torch.save si basa sul sottaceto. Quanto segue è dal tutorial collegato sopra: "[torch.save] salverà l'intero modulo usando il modulo pickle di Python. Lo svantaggio di questo approccio è che i dati serializzati sono associati alle classi specifiche e alla struttura di directory esatta utilizzata quando il modello viene salvato. La ragione di ciò è perché pickle non salva la classe del modello stesso, ma salva un percorso al file contenente la classe, che viene utilizzato durante il tempo di caricamento. Per questo motivo, il codice può rompersi in vari modi quando utilizzato in altri progetti o dopo i rifrattori. "
David Miller,

Risposte:


215

Ho trovato questa pagina nel loro repository github, incollerò qui il contenuto.


Approccio raccomandato per il salvataggio di un modello

Esistono due approcci principali per serializzare e ripristinare un modello.

Il primo (consigliato) salva e carica solo i parametri del modello:

torch.save(the_model.state_dict(), PATH)

Poi più tardi:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

Il secondo salva e carica l'intero modello:

torch.save(the_model, PATH)

Poi più tardi:

the_model = torch.load(PATH)

Tuttavia, in questo caso, i dati serializzati sono associati alle classi specifiche e all'esatta struttura di directory utilizzata, quindi possono rompersi in vari modi se utilizzati in altri progetti o dopo alcuni refactor gravi.


8
Secondo @smth discuss.pytorch.org/t/saving-and-loading-a-model-in-pytorch/… il modello si ricarica per addestrare il modello di default. quindi è necessario chiamare manualmente the_model.eval () dopo il caricamento, se lo si sta caricando per deduzione, non riprendere l'allenamento.
WillZ,

il secondo metodo fornisce stackoverflow.com/questions/53798009/… errore su Windows 10. non è stato in grado di risolverlo
Gulzar

Esiste un'opzione per salvare senza la necessità di un accesso per la classe del modello?
Michael D,

Con questo approccio come tenere traccia di * args e ** kwargs che è necessario passare per il caso di carico?
Mariano Kamp

cosa c'è di sbagliato nell'usare pickle?
Charlie Parker,

145

Dipende da quello che vuoi fare.

Caso n. 1: salvare il modello per utilizzarlo da solo per deduzione : si salva il modello, lo si ripristina e quindi si cambia il modello in modalità di valutazione. Questo perché di solito hai BatchNorme Dropoutlivelli che per impostazione predefinita sono in modalità treno in costruzione:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Caso n. 2: salvare il modello per riprendere l'allenamento in un secondo momento : se è necessario continuare ad allenare il modello che si sta per salvare, è necessario salvare più di un semplice modello. Devi anche salvare lo stato dell'ottimizzatore, le epoche, il punteggio, ecc. Lo faresti in questo modo:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Per riprendere l'allenamento faresti cose come:, state = torch.load(filepath)e poi, per ripristinare lo stato di ogni singolo oggetto, qualcosa del genere:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Poiché si riprende l'allenamento, NON chiamare model.eval()una volta ripristinati gli stati durante il caricamento.

Caso n. 3: modello che può essere utilizzato da qualcun altro senza accesso al codice : in Tensorflow è possibile creare un .pbfile che definisce sia l'architettura che i pesi del modello. Questo è molto utile, specialmente durante l'utilizzo Tensorflow serve. Il modo equivalente per farlo in Pytorch sarebbe:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

In questo modo non è ancora a prova di proiettile e poiché Pytorch sta ancora subendo molti cambiamenti, non lo consiglierei.


1
Esiste un file consigliato che termina per i 3 casi? O è sempre .pth?
Verena Haunschmid,

1
Nel caso n. 3 torch.loadrestituisce solo un OrderedDict. Come si ottiene il modello per fare previsioni?
Alber8295,

Ciao, posso sapere come fare il "Caso n. 2: Salva modello per riprendere l'allenamento in seguito"? Sono riuscito a caricare il checkpoint sul modello, quindi non sono riuscito a eseguire o riprendere a formare un modello come "model.to (dispositivo) modello = train_model_epoch (modello, criterio, ottimizzatore, sched, epoche)"
dnez

1
Salve, per il primo caso che è per inferenza, nel documento ufficiale di pytorch dire che deve salvare l'ottimizzatore state_dict sia per inferenza che per completare la formazione. "Quando si salva un checkpoint generale, da utilizzare per dedurre o riprendere l'addestramento, è necessario risparmiare più del semplice state_dict del modello. È importante salvare anche state_dict dell'ottimizzatore, poiché contiene buffer e parametri che vengono aggiornati come treni del modello "
Mohammed Awney,

1
Nel caso n. 3, la classe del modello deve essere definita da qualche parte.
Michael D,

12

La libreria pickle Python implementa protocolli binari per serializzare e deserializzare un oggetto Python.

Quando tu import torch(o quando usi PyTorch) lo farà import pickleper te e non hai bisogno di chiamare pickle.dump()e pickle.load()direttamente, quali sono i metodi per salvare e caricare l'oggetto.

In effetti, torch.save()e torch.load()avvolgerà pickle.dump()e pickle.load()per te.

Un'altra state_dictrisposta menzionata merita solo qualche altra nota.

Cosa state_dictabbiamo dentro PyTorch? In realtà ci sono due state_dictsecondi.

Il modello è PyTorch torch.nn.Moduleha model.parameters()chiamata per ottenere i parametri apprendibili (w eb). Questi parametri apprendibili, una volta impostati casualmente, si aggiorneranno nel tempo man mano che apprendiamo. I parametri apprendibili sono i primi state_dict.

Il secondo state_dictè lo stato dell'ottimizzatore. Ricordi che l'ottimizzatore viene utilizzato per migliorare i nostri parametri apprendibili. Ma l'ottimizzatore state_dictè stato risolto. Niente da imparare lì dentro.

Poiché gli state_dictoggetti sono dizionari Python, possono essere facilmente salvati, aggiornati, modificati e ripristinati, aggiungendo molta modularità ai modelli e agli ottimizzatori PyTorch.

Creiamo un modello super semplice per spiegare questo:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Questo codice genererà quanto segue:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Nota che questo è un modello minimo. Puoi provare ad aggiungere una pila di sequenziali

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Si noti che solo i livelli con parametri apprendibili (livelli convoluzionali, livelli lineari, ecc.) E buffer registrati (livelli batchnorm) hanno voci nei modelli state_dict.

Le cose non apprendibili appartengono all'ottimizzatore state_dict, che contiene informazioni sullo stato dell'ottimizzatore e sugli iperparametri utilizzati.

Il resto della storia è lo stesso; nella fase di inferenza (questa è una fase in cui utilizziamo il modello dopo l'allenamento) per la previsione; prevediamo in base ai parametri appresi. Quindi per l'inferenza, dobbiamo solo salvare i parametri model.state_dict().

torch.save(model.state_dict(), filepath)

E per utilizzare successivamente model.load_state_dict (torch.load (filepath)) model.eval ()

Nota: non dimenticare l'ultima riga che model.eval()è cruciale dopo aver caricato il modello.

Inoltre, non provare a salvare torch.save(model.parameters(), filepath). Il model.parameters()è solo l'oggetto generatore.

Dall'altro lato, torch.save(model, filepath)salva l'oggetto modello stesso, ma tieni presente che il modello non ha l'ottimizzatore state_dict. Controlla l'altra eccellente risposta di @Jadiel de Armas per salvare il dict di stato dell'ottimizzatore.


Sebbene non sia una soluzione semplice, l'essenza del problema viene analizzata a fondo! Upvote.
Jason Young,

7

Una convenzione PyTorch comune è quella di salvare i modelli utilizzando l'estensione di file .pt o .pth.

Salva / Carica intero modello Salva:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Caricare:

La classe del modello deve essere definita da qualche parte

model = torch.load(PATH)
model.eval()

4

Se si desidera salvare il modello e si desidera riprendere la formazione in un secondo momento:

GPU singola: Salva:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Caricare:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

GPU multipla: Salva

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Caricare:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.