Qual è l'uso di torch.no_grad in pytorch?


21

Sono nuovo di Pytorch e ho iniziato con questo codice github. Non capisco il commento nella riga 60-61 nel codice "because weights have requires_grad=True, but we don't need to track this in autograd". Ho capito che menzioniamo requires_grad=Truele variabili di cui abbiamo bisogno per calcolare i gradienti per l'utilizzo di autograd, ma cosa significa essere "tracked by autograd"?

Risposte:


24

Il wrapper "con torch.no_grad ()" ha temporaneamente impostato tutti i flag require_grad su false. Un esempio dal tutorial ufficiale di PyTorch ( https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients ):

x = torch.randn(3, requires_grad=True)
print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)

Su:

True
True
False

Ti consiglio di leggere tutti i tutorial dal sito Web sopra.

Nel tuo esempio: suppongo che l'autore non voglia che PyTorch calcoli i gradienti delle nuove variabili definite w1 e w2 poiché vuole solo aggiornare i loro valori.


6
with torch.no_grad()

farà sì che tutte le operazioni nel blocco non abbiano pendenze.

In pytorch, non è possibile eseguire il cambio di posizionamento di w1 e w2, che sono due variabili con require_grad = True. Penso che evitare il cambio di posizionamento di w1 e w2 sia perché causerà errori nel calcolo della propagazione posteriore. Poiché il cambio di collocamento cambierà totalmente w1 e w2.

Tuttavia, se lo si utilizza no_grad(), è possibile controllare il nuovo w1 e il nuovo w2 non ha pendenze poiché sono generati da operazioni, il che significa che si cambia solo il valore di w1 e w2, non parte del gradiente, hanno ancora informazioni di gradiente variabile definite in precedenza e la propagazione posteriore può continuare.

Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.