TensorFlow, perché ci sono 3 file dopo aver salvato il modello?


113

Dopo aver letto i documenti , ho salvato un modello in TensorFlow, ecco il mio codice demo:

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  ..
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in file: %s" % save_path)

ma dopo ho scoperto che ci sono 3 file

model.ckpt.data-00000-of-00001
model.ckpt.index
model.ckpt.meta

E non posso ripristinare il modello ripristinando il model.ckptfile, poiché non esiste un file di questo tipo. Ecco il mio codice

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

Allora, perché ci sono 3 file?


2
Hai capito come affrontare questo problema? Come posso caricare di nuovo il modello (utilizzando Keras)?
rajkiran

Risposte:


116

Prova questo:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/tmp/model.ckpt.meta')
    saver.restore(sess, "/tmp/model.ckpt")

Il metodo di salvataggio TensorFlow salva tre tipi di file poiché memorizza la struttura del grafico separatamente dai valori delle variabili . Il .metafile descrive la struttura del grafico salvato, quindi è necessario importarlo prima di ripristinare il checkpoint (altrimenti non sa a quali variabili corrispondono i valori del checkpoint salvato).

In alternativa, potresti farlo:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")

...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/tmp/model.ckpt")

Anche se non è presente alcun file denominato model.ckpt, fai comunque riferimento al checkpoint salvato con quel nome quando lo ripristini. Dal saver.pycodice sorgente :

Gli utenti devono solo interagire con il prefisso specificato dall'utente ... invece di qualsiasi percorso fisico.


1
quindi il .index e il .data non vengono utilizzati? Quando vengono utilizzati quei 2 file, allora?
ajfbiw.s

26
@ ajfbiw.s .meta memorizza la struttura del grafico, .data memorizza i valori di ogni variabile nel grafico, .index identifica il checkpiont. Quindi nell'esempio sopra: import_meta_graph usa .meta e saver.restore usa .data e .index
TK Bartel

Oh, capisco. Grazie.
ajfbiw.s

1
C'è qualche possibilità che tu abbia salvato il modello con una versione di TensorFlow diversa da quella che stai utilizzando per caricarlo? ( github.com/tensorflow/tensorflow/issues/5639 )
TK Bartel

5
Qualcuno sa cosa significano questo 00000e i 00001numeri? in variables.data-?????-of-?????archivio
Ivan Talalaev

55
  • meta file : descrive la struttura del grafico salvato, include GraphDef, SaverDef e così via; quindi applicare tf.train.import_meta_graph('/tmp/model.ckpt.meta'), ripristinerà Savere Graph.

  • file indice : è una tabella immutabile stringa-stringa (tensorflow :: table :: Table). Ogni chiave è un nome di un tensore e il suo valore è un BundleEntryProto serializzato. Ogni BundleEntryProto descrive i metadati di un tensore: quale dei file "dati" contiene il contenuto di un tensore, l'offset in quel file, il checksum, alcuni dati ausiliari, ecc.

  • file di dati : è la raccolta TensorBundle, salva i valori di tutte le variabili.


Ho il file pb che ho per la classificazione delle immagini. Posso usarlo per la classificazione dei video in tempo reale?

Potete farmi sapere, utilizzando Keras 2, come faccio a caricare il modello se viene salvato come 3 file?
rajkiran

5

Sto ripristinando i word embedding addestrati dal tutorial di Word2Vec tensorflow.

Nel caso in cui tu abbia creato più checkpoint:

ad esempio, i file creati hanno questo aspetto

model.ckpt-55695.data-00000-of-00001

model.ckpt-55695.index

model.ckpt-55695.meta

prova questo

def restore_session(self, session):
   saver = tf.train.import_meta_graph('./tmp/model.ckpt-55695.meta')
   saver.restore(session, './tmp/model.ckpt-55695')

quando si chiama restore_session ():

def test_word2vec():
   opts = Options()    
   with tf.Graph().as_default(), tf.Session() as session:
       with tf.device("/cpu:0"):            
           model = Word2Vec(opts, session)
           model.restore_session(session)
           model.get_embedding("assistance")

Cosa significa "00000-of-00001" in "model.ckpt-55695.data-00000-of-00001"?
hafiz031

0

Ad esempio, se hai addestrato una CNN con abbandono scolastico, potresti farlo:

def predict(image, model_name):
    """
    image -> single image, (width, height, channels)
    model_name -> model file that was saved without any extensions
    """
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph('./' + model_name + '.meta')
        saver.restore(sess, './' + model_name)
        # Substitute 'logits' with your model
        prediction = tf.argmax(logits, 1)
        # 'x' is what you defined it to be. In my case it is a batch of RGB images, that's why I add the extra dimension
        return prediction.eval(feed_dict={x: image[np.newaxis,:,:,:], keep_prob_dnn: 1.0})
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.