In poche parole, torch.Tensor.view()che è ispirato da numpy.ndarray.reshape()o numpy.reshape(), crea una nuova vista del tensore, purché la nuova forma sia compatibile con la forma del tensore originale.
Comprendiamolo in dettaglio usando un esempio concreto.
In [43]: t = torch.arange(18)
In [44]: t
Out[44]:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17])
Con questo tensore tdi forma (18,), le nuove viste possono essere create solo per le seguenti forme:
(1, 18)o equivalentemente (1, -1)o o equivalentemente o o equivalentemente o o equivalentemente o o equivalentemente o o equivalentemente o(-1, 18)
(2, 9)(2, -1)(-1, 9)
(3, 6)(3, -1)(-1, 6)
(6, 3)(6, -1)(-1, 3)
(9, 2)(9, -1)(-1, 2)
(18, 1)(18, -1)(-1, 1)
Come possiamo già osservare dalle tuple di forma sopra, la moltiplicazione degli elementi della tupla di forma (ad es . 2*9, 3*6Ecc.) Deve sempre essere uguale al numero totale di elementi nel tensore originale (18 nel nostro esempio).
Un'altra cosa da osservare è che abbiamo usato a -1in uno dei punti in ciascuna delle tuple di forma. Usando a -1, siamo pigri nel fare noi stessi il calcolo e piuttosto deleghiamo l'attività a PyTorch per fare il calcolo di quel valore per la forma quando crea la nuova vista . Una cosa importante da notare è che possiamo usare solo un singolo -1nella tupla di forma. I valori rimanenti devono essere esplicitamente forniti da noi. Else PyTorch si lamenterà lanciando un RuntimeError:
RuntimeError: è possibile dedurre solo una dimensione
Quindi, con tutte le forme sopra menzionate, PyTorch restituirà sempre una nuova vista del tensore originalet . Ciò significa sostanzialmente che cambia semplicemente le informazioni sul passo del tensore per ciascuna delle nuove viste richieste.
Di seguito sono riportati alcuni esempi che illustrano come i passi dei tensori vengono modificati con ogni nuova vista .
# stride of our original tensor `t`
In [53]: t.stride()
Out[53]: (1,)
Ora vedremo i passi da compiere per le nuove viste :
# shape (1, 18)
In [54]: t1 = t.view(1, -1)
# stride tensor `t1` with shape (1, 18)
In [55]: t1.stride()
Out[55]: (18, 1)
# shape (2, 9)
In [56]: t2 = t.view(2, -1)
# stride of tensor `t2` with shape (2, 9)
In [57]: t2.stride()
Out[57]: (9, 1)
# shape (3, 6)
In [59]: t3 = t.view(3, -1)
# stride of tensor `t3` with shape (3, 6)
In [60]: t3.stride()
Out[60]: (6, 1)
# shape (6, 3)
In [62]: t4 = t.view(6,-1)
# stride of tensor `t4` with shape (6, 3)
In [63]: t4.stride()
Out[63]: (3, 1)
# shape (9, 2)
In [65]: t5 = t.view(9, -1)
# stride of tensor `t5` with shape (9, 2)
In [66]: t5.stride()
Out[66]: (2, 1)
# shape (18, 1)
In [68]: t6 = t.view(18, -1)
# stride of tensor `t6` with shape (18, 1)
In [69]: t6.stride()
Out[69]: (1, 1)
Quindi questa è la magia della view()funzione. Cambia solo i passi del tensore (originale) per ciascuna delle nuove viste , purché la forma della nuova vista sia compatibile con la forma originale.
Un'altra cosa interessante potrebbe osservare dalle stride tuple è che il valore dell'elemento nella 0 ° posizione è uguale al valore dell'elemento nella 1 ° posizione della tupla forma.
In [74]: t3.shape
Out[74]: torch.Size([3, 6])
|
In [75]: t3.stride() |
Out[75]: (6, 1) |
|_____________|
Questo è perché:
In [76]: t3
Out[76]:
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17]])
il passo (6, 1)dice che per passare da un elemento all'elemento successivo lungo la 0a dimensione, dobbiamo saltare o fare 6 passi. (cioè per passare da 0a 6, si deve prendere 6 punti). Ma per passare da un elemento con l'elemento successivo nella 1 ° dimensione, abbiamo solo bisogno di un solo passo (per esempio per andare da 2a3 ).
Pertanto, le informazioni sui passi sono al centro di come gli elementi sono accessibili dalla memoria per eseguire il calcolo.
Questa funzione restituirebbe una vista ed è esattamente uguale all'utilizzo torch.Tensor.view()purché la nuova forma sia compatibile con la forma del tensore originale. Altrimenti, restituirà una copia.
Tuttavia, le note di torch.reshape()avvertono che:
gli input contigui e gli input con passi compatibili possono essere rimodellati senza copiare, ma non si dovrebbe dipendere dal comportamento di copia vs. visualizzazione.
reshapein PyTorch ?!