In poche parole, torch.Tensor.view()
che è ispirato da numpy.ndarray.reshape()
o numpy.reshape()
, crea una nuova vista del tensore, purché la nuova forma sia compatibile con la forma del tensore originale.
Comprendiamolo in dettaglio usando un esempio concreto.
In [43]: t = torch.arange(18)
In [44]: t
Out[44]:
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17])
Con questo tensore t
di forma (18,)
, le nuove viste possono essere create solo per le seguenti forme:
(1, 18)
o equivalentemente (1, -1)
o o equivalentemente o o equivalentemente o o equivalentemente o o equivalentemente o o equivalentemente o(-1, 18)
(2, 9)
(2, -1)
(-1, 9)
(3, 6)
(3, -1)
(-1, 6)
(6, 3)
(6, -1)
(-1, 3)
(9, 2)
(9, -1)
(-1, 2)
(18, 1)
(18, -1)
(-1, 1)
Come possiamo già osservare dalle tuple di forma sopra, la moltiplicazione degli elementi della tupla di forma (ad es . 2*9
, 3*6
Ecc.) Deve sempre essere uguale al numero totale di elementi nel tensore originale (18
nel nostro esempio).
Un'altra cosa da osservare è che abbiamo usato a -1
in uno dei punti in ciascuna delle tuple di forma. Usando a -1
, siamo pigri nel fare noi stessi il calcolo e piuttosto deleghiamo l'attività a PyTorch per fare il calcolo di quel valore per la forma quando crea la nuova vista . Una cosa importante da notare è che possiamo usare solo un singolo -1
nella tupla di forma. I valori rimanenti devono essere esplicitamente forniti da noi. Else PyTorch si lamenterà lanciando un RuntimeError
:
RuntimeError: è possibile dedurre solo una dimensione
Quindi, con tutte le forme sopra menzionate, PyTorch restituirà sempre una nuova vista del tensore originalet
. Ciò significa sostanzialmente che cambia semplicemente le informazioni sul passo del tensore per ciascuna delle nuove viste richieste.
Di seguito sono riportati alcuni esempi che illustrano come i passi dei tensori vengono modificati con ogni nuova vista .
# stride of our original tensor `t`
In [53]: t.stride()
Out[53]: (1,)
Ora vedremo i passi da compiere per le nuove viste :
# shape (1, 18)
In [54]: t1 = t.view(1, -1)
# stride tensor `t1` with shape (1, 18)
In [55]: t1.stride()
Out[55]: (18, 1)
# shape (2, 9)
In [56]: t2 = t.view(2, -1)
# stride of tensor `t2` with shape (2, 9)
In [57]: t2.stride()
Out[57]: (9, 1)
# shape (3, 6)
In [59]: t3 = t.view(3, -1)
# stride of tensor `t3` with shape (3, 6)
In [60]: t3.stride()
Out[60]: (6, 1)
# shape (6, 3)
In [62]: t4 = t.view(6,-1)
# stride of tensor `t4` with shape (6, 3)
In [63]: t4.stride()
Out[63]: (3, 1)
# shape (9, 2)
In [65]: t5 = t.view(9, -1)
# stride of tensor `t5` with shape (9, 2)
In [66]: t5.stride()
Out[66]: (2, 1)
# shape (18, 1)
In [68]: t6 = t.view(18, -1)
# stride of tensor `t6` with shape (18, 1)
In [69]: t6.stride()
Out[69]: (1, 1)
Quindi questa è la magia della view()
funzione. Cambia solo i passi del tensore (originale) per ciascuna delle nuove viste , purché la forma della nuova vista sia compatibile con la forma originale.
Un'altra cosa interessante potrebbe osservare dalle stride tuple è che il valore dell'elemento nella 0 ° posizione è uguale al valore dell'elemento nella 1 ° posizione della tupla forma.
In [74]: t3.shape
Out[74]: torch.Size([3, 6])
|
In [75]: t3.stride() |
Out[75]: (6, 1) |
|_____________|
Questo è perché:
In [76]: t3
Out[76]:
tensor([[ 0, 1, 2, 3, 4, 5],
[ 6, 7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17]])
il passo (6, 1)
dice che per passare da un elemento all'elemento successivo lungo la 0a dimensione, dobbiamo saltare o fare 6 passi. (cioè per passare da 0
a 6
, si deve prendere 6 punti). Ma per passare da un elemento con l'elemento successivo nella 1 ° dimensione, abbiamo solo bisogno di un solo passo (per esempio per andare da 2
a3
).
Pertanto, le informazioni sui passi sono al centro di come gli elementi sono accessibili dalla memoria per eseguire il calcolo.
Questa funzione restituirebbe una vista ed è esattamente uguale all'utilizzo torch.Tensor.view()
purché la nuova forma sia compatibile con la forma del tensore originale. Altrimenti, restituirà una copia.
Tuttavia, le note di torch.reshape()
avvertono che:
gli input contigui e gli input con passi compatibili possono essere rimodellati senza copiare, ma non si dovrebbe dipendere dal comportamento di copia vs. visualizzazione.
reshape
in PyTorch ?!