La libreria pickle Python implementa protocolli binari per serializzare e deserializzare un oggetto Python.
Quando tu import torch
(o quando usi PyTorch) lo farà import pickle
per te e non hai bisogno di chiamare pickle.dump()
e pickle.load()
direttamente, quali sono i metodi per salvare e caricare l'oggetto.
In effetti, torch.save()
e torch.load()
avvolgerà pickle.dump()
e pickle.load()
per te.
Un'altra state_dict
risposta menzionata merita solo qualche altra nota.
Cosa state_dict
abbiamo dentro PyTorch? In realtà ci sono due state_dict
secondi.
Il modello è PyTorch torch.nn.Module
ha model.parameters()
chiamata per ottenere i parametri apprendibili (w eb). Questi parametri apprendibili, una volta impostati casualmente, si aggiorneranno nel tempo man mano che apprendiamo. I parametri apprendibili sono i primi state_dict
.
Il secondo state_dict
è lo stato dell'ottimizzatore. Ricordi che l'ottimizzatore viene utilizzato per migliorare i nostri parametri apprendibili. Ma l'ottimizzatore state_dict
è stato risolto. Niente da imparare lì dentro.
Poiché gli state_dict
oggetti sono dizionari Python, possono essere facilmente salvati, aggiornati, modificati e ripristinati, aggiungendo molta modularità ai modelli e agli ottimizzatori PyTorch.
Creiamo un modello super semplice per spiegare questo:
import torch
import torch.optim as optim
model = torch.nn.Linear(5, 2)
# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print("Model's state_dict:")
for param_tensor in model.state_dict():
print(param_tensor, "\t", model.state_dict()[param_tensor].size())
print("Model weight:")
print(model.weight)
print("Model bias:")
print(model.bias)
print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
print(var_name, "\t", optimizer.state_dict()[var_name])
Questo codice genererà quanto segue:
Model's state_dict:
weight torch.Size([2, 5])
bias torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328, 0.1360, 0.1553, -0.1838, -0.0316],
[ 0.0479, 0.1760, 0.1712, 0.2244, 0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]
Nota che questo è un modello minimo. Puoi provare ad aggiungere una pila di sequenziali
model = torch.nn.Sequential(
torch.nn.Linear(D_in, H),
torch.nn.Conv2d(A, B, C)
torch.nn.Linear(H, D_out),
)
Si noti che solo i livelli con parametri apprendibili (livelli convoluzionali, livelli lineari, ecc.) E buffer registrati (livelli batchnorm) hanno voci nei modelli state_dict
.
Le cose non apprendibili appartengono all'ottimizzatore state_dict
, che contiene informazioni sullo stato dell'ottimizzatore e sugli iperparametri utilizzati.
Il resto della storia è lo stesso; nella fase di inferenza (questa è una fase in cui utilizziamo il modello dopo l'allenamento) per la previsione; prevediamo in base ai parametri appresi. Quindi per l'inferenza, dobbiamo solo salvare i parametri model.state_dict()
.
torch.save(model.state_dict(), filepath)
E per utilizzare successivamente model.load_state_dict (torch.load (filepath)) model.eval ()
Nota: non dimenticare l'ultima riga che model.eval()
è cruciale dopo aver caricato il modello.
Inoltre, non provare a salvare torch.save(model.parameters(), filepath)
. Il model.parameters()
è solo l'oggetto generatore.
Dall'altro lato, torch.save(model, filepath)
salva l'oggetto modello stesso, ma tieni presente che il modello non ha l'ottimizzatore state_dict
. Controlla l'altra eccellente risposta di @Jadiel de Armas per salvare il dict di stato dell'ottimizzatore.