TensorFlow salva / carica un grafico da un file


98

Da quanto ho raccolto finora, ci sono diversi modi per scaricare un grafico TensorFlow in un file e quindi caricarlo in un altro programma, ma non sono stato in grado di trovare esempi / informazioni chiari su come funzionano. Quello che già so è questo:

  1. Salvare le variabili del modello in un file di checkpoint (.ckpt) utilizzando a tf.train.Saver()e ripristinarle in seguito ( sorgente )
  2. Salva un modello in un file .pb e caricalo di nuovo usando tf.train.write_graph()e tf.import_graph_def()( sorgente )
  3. Carica un modello da un file .pb, riqualificalo e scaricalo in un nuovo file .pb usando Bazel ( sorgente )
  4. Blocca il grafico per salvare il grafico e i pesi insieme ( fonte )
  5. Utilizzare as_graph_def()per salvare il modello e, per pesi / variabili, mapparli in costanti ( sorgente )

Tuttavia, non sono stato in grado di chiarire diverse domande su questi diversi metodi:

  1. Per quanto riguarda i file di checkpoint, salvano solo i pesi addestrati di un modello? I file del checkpoint possono essere caricati in un nuovo programma e utilizzati per eseguire il modello o servono semplicemente come modi per salvare i pesi in un modello in un determinato momento / fase?
  2. A proposito tf.train.write_graph(), vengono salvati anche i pesi / variabili?
  3. Per quanto riguarda Bazel, può solo salvare in / caricare da file .pb per la riqualificazione? Esiste un semplice comando Bazel solo per eseguire il dump di un grafico in un .pb?
  4. Per quanto riguarda il congelamento, è possibile caricare un grafico congelato utilizzando tf.import_graph_def()?
  5. La demo Android per TensorFlow viene caricata nel modello Inception di Google da un file .pb. Se volessi sostituire il mio file .pb, come dovrei fare per farlo? Avrei bisogno di cambiare codice / metodo nativo?
  6. In generale, qual è esattamente la differenza tra tutti questi metodi? O più in generale, qual è la differenza tra as_graph_def()/.ckpt/.pb?

In breve, quello che sto cercando è un metodo per salvare sia un grafico (come in, le varie operazioni e simili) che i suoi pesi / variabili in un file, che può poi essere utilizzato per caricare il grafico e i pesi in un altro programma , per l'uso (non necessariamente continuazione / riqualificazione).

La documentazione su questo argomento non è molto semplice, quindi qualsiasi risposta / informazione sarebbe molto apprezzata.


2
L'API più recente / più completa è il meta grafico, che ti darà un modo per salvare tutti e tre in una volta: 1) grafico 2) valori dei parametri 3) raccolte: tensorflow.org/versions/r0.10/how_tos/meta_graph/ index.html
Yaroslav Bulatov

Risposte:


80

Esistono molti modi per affrontare il problema del salvataggio di un modello in TensorFlow, il che può creare un po 'di confusione. Prendendo a turno ciascuna delle tue domande secondarie:

  1. I file del punto di controllo (prodotti ad esempio invocando saver.save()un tf.train.Saveroggetto) contengono solo i pesi e qualsiasi altra variabile definita nello stesso programma. Per usarli in un altro programma, è necessario ricreare la struttura del grafico associata (ad esempio eseguendo il codice per ricostruirlo o chiamando tf.import_graph_def()), che dice a TensorFlow cosa fare con quei pesi. Notare che la chiamata saver.save()produce anche un file contenente a MetaGraphDef, che contiene un grafico e dettagli su come associare i pesi da un checkpoint a quel grafico. Guarda il tutorial per maggiori dettagli.

  2. tf.train.write_graph()scrive solo la struttura del grafico; non i pesi.

  3. Bazel non è correlato alla lettura o alla scrittura di grafici TensorFlow. (Forse fraintendo la tua domanda: sentiti libero di chiarirla in un commento.)

  4. È possibile caricare un grafico congelato utilizzando tf.import_graph_def(). In questo caso, i pesi sono (in genere) incorporati nel grafico, quindi non è necessario caricare un checkpoint separato.

  5. La modifica principale consiste nell'aggiornamento dei nomi dei tensori inseriti nel modello e dei nomi dei tensori recuperati dal modello. Nella demo di TensorFlow Android, questo corrisponderebbe alle stringhe inputNamee a outputNamecui vengono passati TensorFlowClassifier.initializeTensorFlow().

  6. La GraphDefè la struttura del programma, che in genere non cambia attraverso il processo di formazione. Il checkpoint è un'istantanea dello stato di un processo di formazione, che in genere cambia in ogni fase del processo di formazione. Di conseguenza, TensorFlow utilizza diversi formati di archiviazione per questi tipi di dati e l'API di basso livello offre diversi modi per salvarli e caricarli. Le librerie di livello superiore, come le MetaGraphDeflibrerie, Keras e skflow si basano su questi meccanismi per fornire modi più convenienti per salvare e ripristinare un intero modello.


Questo significa che la documentazione dell'API C ++ si trova quando dice che puoi caricare il grafico salvato con tf.train.write_graph()e quindi eseguirlo?
mnicky

2
La documentazione dell'API C ++ non mente, ma mancano alcuni dettagli. Il dettaglio più importante è che, oltre a quelli GraphDefsalvati da tf.train.write_graph(), è necessario ricordare anche i nomi dei tensori che si desidera alimentare e recuperare durante l'esecuzione del grafico (punto 5 sopra).
Mrry

@mrry: ho provato a utilizzare l'esempio DeepDream di tensorflows. ma sembra che abbia bisogno di modelli pre-addestrati in formato pb! Ho eseguito l'esempio Cifar10, ma crea solo checkpoint! Non sono riuscito a trovare alcun file pb o altro! come posso convertire i miei checkpoint nel formato pb che usa l'esempio di deepdream?
Rika

2
@ Coderx7 Penso davvero che tu non possa convertire un .ckpt in un .pb poiché il checkpoint contiene solo i pesi e le variabili e non sa nulla della struttura del grafico
davidivad

1
c'è un semplice codice per caricare un file .pb e poi eseguirlo?
Kong

1

Puoi provare il codice seguente:

with tf.gfile.FastGFile('model/frozen_inference_graph.pb', "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    g_in = tf.import_graph_def(graph_def, name="")
sess = tf.Session(graph=g_in)
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.