Differenza tra Variable e get_variable in TensorFlow


125

Per quanto ne so, Variableè l'operazione predefinita per creare una variabile e get_variableviene utilizzata principalmente per la condivisione del peso.

Da un lato, ci sono alcune persone che suggeriscono di utilizzare get_variableinvece Variabledell'operazione primitiva ogni volta che è necessaria una variabile. D'altra parte, vedo semplicemente qualsiasi uso di get_variablenei documenti ufficiali e nelle demo di TensorFlow.

Quindi voglio conoscere alcune regole pratiche su come utilizzare correttamente questi due meccanismi. Esistono principi "standard"?


6
get_variable è un modo nuovo, Variable è un modo vecchio (che potrebbe essere supportato per sempre) come dice Lukasz (PS: ha scritto gran parte dell'ambito del nome della variabile in TF)
Yaroslav Bulatov

Risposte:


90

Consiglio di usarlo sempre tf.get_variable(...): renderà molto più semplice il refactoring del codice se è necessario condividere variabili in qualsiasi momento, ad esempio in un'impostazione multi-gpu (vedere l'esempio CIFAR multi-gpu). Non ci sono svantaggi.

Pure tf.Variableè di livello inferiore; a un certo punto tf.get_variable()non esisteva quindi del codice utilizza ancora il modo di basso livello.


5
Grazie mille per la tua risposta. Ma ho ancora una domanda su come sostituire tf.Variablecon tf.get_variableovunque. Questo è quando voglio inizializzare una variabile con un array numpy, non riesco a trovare un modo pulito ed efficiente per farlo come faccio con tf.Variable. Come risolverlo? Grazie.
Lifu Huang

69

tf.Variable è una classe e ci sono diversi modi per creare tf.Variable tra cui tf.Variable.__init__e tf.get_variable.

tf.Variable.__init__: Crea una nuova variabile con valore_iniziale .

W = tf.Variable(<initial-value>, name=<optional-name>)

tf.get_variable: Ottiene una variabile esistente con questi parametri o ne crea una nuova. Puoi anche usare l'inizializzatore.

W = tf.get_variable(name, shape=None, dtype=tf.float32, initializer=None,
       regularizer=None, trainable=True, collections=None)

È molto utile utilizzare inizializzatori come xavier_initializer:

W = tf.get_variable("W", shape=[784, 256],
       initializer=tf.contrib.layers.xavier_initializer())

Maggiori informazioni qui .


Sì, in Variablerealtà intendo usare il suo __init__. Dato che get_variableè così conveniente, mi chiedo perché la maggior parte del codice TensorFlow che ho visto usare al Variableposto di get_variable. Ci sono convenzioni o fattori da considerare quando si sceglie tra di loro. Grazie!
Lifu Huang

Se vuoi avere un certo valore, usare Variable è semplice: x = tf.Variable (3).
Sung Kim

@SungKim normalmente quando usiamo tf.Variable()possiamo inizializzarlo come un valore casuale da una distribuzione normale troncata. Ecco il mio esempio w1 = tf.Variable(tf.truncated_normal([5, 50], stddev = 0.01), name = 'w1'). Quale sarebbe l'equivalente di questo? come faccio a dirgli che voglio un normale troncato? Dovrei solo fare w1 = tf.get_variable(name = 'w1', shape = [5,50], initializer = tf.truncated_normal, regularizer = tf.nn.l2_loss)?
Euler_Salter

@Euler_Salter: puoi usare tf.truncated_normal_initializer()per ottenere il risultato desiderato.
Beta

46

Riesco a trovare due differenze principali tra l'una e l'altra:

  1. Il primo è che tf.Variablecreerà sempre una nuova variabile, mentre tf.get_variableottiene una variabile esistente con parametri specificati dal grafico e, se non esiste, ne crea una nuova.

  2. tf.Variable richiede che venga specificato un valore iniziale.

È importante chiarire che la funzione tf.get_variableantepone al nome l'ambito della variabile corrente per eseguire i controlli di riutilizzo. Per esempio:

with tf.variable_scope("one"):
    a = tf.get_variable("v", [1]) #a.name == "one/v:0"
with tf.variable_scope("one"):
    b = tf.get_variable("v", [1]) #ValueError: Variable one/v already exists
with tf.variable_scope("one", reuse = True):
    c = tf.get_variable("v", [1]) #c.name == "one/v:0"

with tf.variable_scope("two"):
    d = tf.get_variable("v", [1]) #d.name == "two/v:0"
    e = tf.Variable(1, name = "v", expected_shape = [1]) #e.name == "two/v_1:0"

assert(a is c)  #Assertion is true, they refer to the same object.
assert(a is d)  #AssertionError: they are different objects
assert(d is e)  #AssertionError: they are different objects

L'ultimo errore di asserzione è interessante: si suppone che due variabili con lo stesso nome sotto lo stesso ambito siano la stessa variabile. Ma se provi i nomi delle variabili de eti renderai conto che Tensorflow ha cambiato il nome della variabile e:

d.name   #d.name == "two/v:0"
e.name   #e.name == "two/v_1:0"

Ottimo esempio! Per quanto riguarda d.namee e.name, mi sono appena imbattuto in questo documento TensorFlow sull'operazione di denominazione dei grafici tensoriali che lo spiega:If the default graph already contained an operation named "answer", the TensorFlow would append "_1", "_2", and so on to the name, in order to make it unique.
Atlas7

2

Un'altra differenza sta nel fatto che uno è in ('variable_store',)collezione ma l'altro no.

Si prega di consultare il codice sorgente :

def _get_default_variable_store():
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store

Lasciatemi illustrare che:

import tensorflow as tf
from tensorflow.python.framework import ops

embedding_1 = tf.Variable(tf.constant(1.0, shape=[30522, 1024]), name="word_embeddings_1", dtype=tf.float32) 
embedding_2 = tf.get_variable("word_embeddings_2", shape=[30522, 1024])

graph = tf.get_default_graph()
collections = graph.collections

for c in collections:
    stores = ops.get_collection(c)
    print('collection %s: ' % str(c))
    for k, store in enumerate(stores):
        try:
            print('\t%d: %s' % (k, str(store._vars)))
        except:
            print('\t%d: %s' % (k, str(store)))
    print('')

Il risultato:

collection ('__variable_store',): 0: {'word_embeddings_2': <tf.Variable 'word_embeddings_2:0' shape=(30522, 1024) dtype=float32_ref>}

Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.