Tensorflow: come salvare / ripristinare un modello?


553

Dopo aver addestrato un modello a Tensorflow:

  1. Come si salva il modello addestrato?
  2. Come si ripristina in seguito questo modello salvato?

Sei riuscito a ripristinare le variabili utilizzate nel modello iniziale? Sto anche provando lo stesso identico problema, ma non sono in grado di scrivere una serie di variabili che sono state utilizzate durante l'allenamento del modello di avvio (di cui ho il file ckpt)
exAres

Non ho provato con il modello iniziale. Hai la struttura di rete del modello con i suoi nomi? Devi replicare la rete e quindi caricare pesi e distorsioni (il file ckpt) come spiega Ryan. Forse qualcosa è cambiato dal 15 novembre e ora c'è un approccio più semplice, non ne sono sicuro
mathetes,

Ohh va bene Ho precedentemente caricato altri modelli tensorflow pre-addestrati, ma cercavo specifiche variabili del modello iniziale. Grazie.
exAres

1
Se ripristini per continuare ad allenarti, utilizza solo i punti di controllo Risparmio. Se si salva il modello per fare riferimento, solo le API SavedModel di tensorflow.
HY G

Inoltre, se stai usando LSTM, avrai una mappa da stringa a un elenco di caratteri, assicurati di salvare e caricare tale elenco nello stesso ordine! Ciò non viene coperto salvando i pesi del modello e la rete dei grafici del modello e sembrerà che il modello non sia stato caricato quando si cambiano le sessioni o le modifiche dei dati.
devssh,

Risposte:


119

Documenti

Dai documenti:

Salva

# Create some variables.
v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
  sess.run(init_op)
  # Do some work with the model.
  inc_v1.op.run()
  dec_v2.op.run()
  # Save the variables to disk.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print("Model saved in path: %s" % save_path)

Ristabilire

tf.reset_default_graph()

# Create some variables.
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Check the values of the variables
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())

Tensorflow 2

Questo è ancora beta, quindi sconsiglio per ora. Se vuoi ancora percorrere questa strada ecco la tf.saved_modelguida all'uso

Tensorflow <2

simple_save

Molte buone risposte, per completezza aggiungerò i miei 2 centesimi: simple_save . Anche un esempio di codice autonomo che utilizza l' tf.data.DatasetAPI.

Python 3; Tensorflow 1.14

import tensorflow as tf
from tensorflow.saved_model import tag_constants

with tf.Graph().as_default():
    with tf.Session() as sess:
        ...

        # Saving
        inputs = {
            "batch_size_placeholder": batch_size_placeholder,
            "features_placeholder": features_placeholder,
            "labels_placeholder": labels_placeholder,
        }
        outputs = {"prediction": model_output}
        tf.saved_model.simple_save(
            sess, 'path/to/your/location/', inputs, outputs
        )

Ripristino:

graph = tf.Graph()
with restored_graph.as_default():
    with tf.Session() as sess:
        tf.saved_model.loader.load(
            sess,
            [tag_constants.SERVING],
            'path/to/your/location/',
        )
        batch_size_placeholder = graph.get_tensor_by_name('batch_size_placeholder:0')
        features_placeholder = graph.get_tensor_by_name('features_placeholder:0')
        labels_placeholder = graph.get_tensor_by_name('labels_placeholder:0')
        prediction = restored_graph.get_tensor_by_name('dense/BiasAdd:0')

        sess.run(prediction, feed_dict={
            batch_size_placeholder: some_value,
            features_placeholder: some_other_value,
            labels_placeholder: another_value
        })

Esempio autonomo

Post del blog originale

Il codice seguente genera dati casuali per il bene della dimostrazione.

  1. Iniziamo creando i segnaposto. Manterranno i dati in fase di esecuzione. Da loro, creiamo il Datasete poi il suo Iterator. Otteniamo il tensore generato dall'iteratore, chiamato input_tensorche servirà come input per il nostro modello.
  2. Il modello stesso è costruito da input_tensor: un RNN bidirezionale basato su GRU seguito da un denso classificatore. Perché perché no.
  3. La perdita è una softmax_cross_entropy_with_logits, ottimizzata con Adam. Dopo 2 epoche (di 2 lotti ciascuno), salviamo il modello "addestrato" con tf.saved_model.simple_save. Se si esegue il codice così com'è, il modello verrà salvato in una cartella chiamata simple/nella directory di lavoro corrente.
  4. In un nuovo grafico, ripristiniamo quindi il modello salvato con tf.saved_model.loader.load. Prendiamo i segnaposto e i log con graph.get_tensor_by_namee l' Iteratoroperazione di inizializzazione con graph.get_operation_by_name.
  5. Infine, eseguiamo un'inferenza per entrambi i batch nel set di dati e controlliamo che il modello salvato e ripristinato restituisca entrambi gli stessi valori. Loro fanno!

Codice:

import os
import shutil
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants


def model(graph, input_tensor):
    """Create the model which consists of
    a bidirectional rnn (GRU(10)) followed by a dense classifier

    Args:
        graph (tf.Graph): Tensors' graph
        input_tensor (tf.Tensor): Tensor fed as input to the model

    Returns:
        tf.Tensor: the model's output layer Tensor
    """
    cell = tf.nn.rnn_cell.GRUCell(10)
    with graph.as_default():
        ((fw_outputs, bw_outputs), (fw_state, bw_state)) = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell,
            cell_bw=cell,
            inputs=input_tensor,
            sequence_length=[10] * 32,
            dtype=tf.float32,
            swap_memory=True,
            scope=None)
        outputs = tf.concat((fw_outputs, bw_outputs), 2)
        mean = tf.reduce_mean(outputs, axis=1)
        dense = tf.layers.dense(mean, 5, activation=None)

        return dense


def get_opt_op(graph, logits, labels_tensor):
    """Create optimization operation from model's logits and labels

    Args:
        graph (tf.Graph): Tensors' graph
        logits (tf.Tensor): The model's output without activation
        labels_tensor (tf.Tensor): Target labels

    Returns:
        tf.Operation: the operation performing a stem of Adam optimizer
    """
    with graph.as_default():
        with tf.variable_scope('loss'):
            loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                    logits=logits, labels=labels_tensor, name='xent'),
                    name="mean-xent"
                    )
        with tf.variable_scope('optimizer'):
            opt_op = tf.train.AdamOptimizer(1e-2).minimize(loss)
        return opt_op


if __name__ == '__main__':
    # Set random seed for reproducibility
    # and create synthetic data
    np.random.seed(0)
    features = np.random.randn(64, 10, 30)
    labels = np.eye(5)[np.random.randint(0, 5, (64,))]

    graph1 = tf.Graph()
    with graph1.as_default():
        # Random seed for reproducibility
        tf.set_random_seed(0)
        # Placeholders
        batch_size_ph = tf.placeholder(tf.int64, name='batch_size_ph')
        features_data_ph = tf.placeholder(tf.float32, [None, None, 30], 'features_data_ph')
        labels_data_ph = tf.placeholder(tf.int32, [None, 5], 'labels_data_ph')
        # Dataset
        dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
        dataset = dataset.batch(batch_size_ph)
        iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
        dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')
        input_tensor, labels_tensor = iterator.get_next()

        # Model
        logits = model(graph1, input_tensor)
        # Optimization
        opt_op = get_opt_op(graph1, logits, labels_tensor)

        with tf.Session(graph=graph1) as sess:
            # Initialize variables
            tf.global_variables_initializer().run(session=sess)
            for epoch in range(3):
                batch = 0
                # Initialize dataset (could feed epochs in Dataset.repeat(epochs))
                sess.run(
                    dataset_init_op,
                    feed_dict={
                        features_data_ph: features,
                        labels_data_ph: labels,
                        batch_size_ph: 32
                    })
                values = []
                while True:
                    try:
                        if epoch < 2:
                            # Training
                            _, value = sess.run([opt_op, logits])
                            print('Epoch {}, batch {} | Sample value: {}'.format(epoch, batch, value[0]))
                            batch += 1
                        else:
                            # Final inference
                            values.append(sess.run(logits))
                            print('Epoch {}, batch {} | Final inference | Sample value: {}'.format(epoch, batch, values[-1][0]))
                            batch += 1
                    except tf.errors.OutOfRangeError:
                        break
            # Save model state
            print('\nSaving...')
            cwd = os.getcwd()
            path = os.path.join(cwd, 'simple')
            shutil.rmtree(path, ignore_errors=True)
            inputs_dict = {
                "batch_size_ph": batch_size_ph,
                "features_data_ph": features_data_ph,
                "labels_data_ph": labels_data_ph
            }
            outputs_dict = {
                "logits": logits
            }
            tf.saved_model.simple_save(
                sess, path, inputs_dict, outputs_dict
            )
            print('Ok')
    # Restoring
    graph2 = tf.Graph()
    with graph2.as_default():
        with tf.Session(graph=graph2) as sess:
            # Restore saved values
            print('\nRestoring...')
            tf.saved_model.loader.load(
                sess,
                [tag_constants.SERVING],
                path
            )
            print('Ok')
            # Get restored placeholders
            labels_data_ph = graph2.get_tensor_by_name('labels_data_ph:0')
            features_data_ph = graph2.get_tensor_by_name('features_data_ph:0')
            batch_size_ph = graph2.get_tensor_by_name('batch_size_ph:0')
            # Get restored model output
            restored_logits = graph2.get_tensor_by_name('dense/BiasAdd:0')
            # Get dataset initializing operation
            dataset_init_op = graph2.get_operation_by_name('dataset_init')

            # Initialize restored dataset
            sess.run(
                dataset_init_op,
                feed_dict={
                    features_data_ph: features,
                    labels_data_ph: labels,
                    batch_size_ph: 32
                }

            )
            # Compute inference for both batches in dataset
            restored_values = []
            for i in range(2):
                restored_values.append(sess.run(restored_logits))
                print('Restored values: ', restored_values[i][0])

    # Check if original inference and restored inference are equal
    valid = all((v == rv).all() for v, rv in zip(values, restored_values))
    print('\nInferences match: ', valid)

Questo stamperà:

$ python3 save_and_restore.py

Epoch 0, batch 0 | Sample value: [-0.13851789 -0.3087595   0.12804556  0.20013677 -0.08229901]
Epoch 0, batch 1 | Sample value: [-0.00555491 -0.04339041 -0.05111827 -0.2480045  -0.00107776]
Epoch 1, batch 0 | Sample value: [-0.19321944 -0.2104792  -0.00602257  0.07465433  0.11674127]
Epoch 1, batch 1 | Sample value: [-0.05275984  0.05981954 -0.15913513 -0.3244143   0.10673307]
Epoch 2, batch 0 | Final inference | Sample value: [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
Epoch 2, batch 1 | Final inference | Sample value: [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

Saving...
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: b'/some/path/simple/saved_model.pb'
Ok

Restoring...
INFO:tensorflow:Restoring parameters from b'/some/path/simple/variables/variables'
Ok
Restored values:  [-0.26331693 -0.13013336 -0.12553    -0.04276478  0.2933622 ]
Restored values:  [-0.07730117  0.11119192 -0.20817074 -0.35660955  0.16990358]

Inferences match:  True

1
Sono principiante e ho bisogno di ulteriori spiegazioni ...: Se ho un modello CNN, dovrei memorizzare solo 1. input_placeholder 2. labels_placeholder e 3. output_of_cnn? O tutto l'intermedio tf.contrib.layers?
Piove il

2
Il grafico è stato completamente ripristinato. Potresti controllarlo in esecuzione [n.name for n in graph2.as_graph_def().node]. Come dice la documentazione, il semplice salvataggio ha lo scopo di semplificare l'interazione con il servizio di tensorflow, questo è il punto degli argomenti; altre variabili vengono comunque ripristinate, altrimenti l'inferenza non avverrebbe. Prendi le tue variabili di interesse come ho fatto nell'esempio. Controlla la documentazione
datata

@ted quando dovrei usare tf.saved_model.simple_save vs tf.train.Saver ()? Dal mio intuito userei tf.train.Saver () durante l'allenamento e per memorizzare diversi momenti nel tempo. Vorrei usare tf.saved_model.simple_save quando l'addestramento è fatto per l'uso in produzione. (Ho chiesto lo stesso anche in un commento qui )
loco.loop

1
Bello immagino, ma funziona anche con i modelli in modalità Eager e tfe.Saver?
Geoffrey Anderson,

1
senza global_stepargomentazioni, se ti fermi, prova a riprendere ad allenarti, penserà di essere un passo uno. Almeno rovinerà le visualizzazioni del tuo tensorboard
Monica Heddneck,

252

Sto migliorando la mia risposta per aggiungere ulteriori dettagli per il salvataggio e il ripristino di modelli.

Nella (e dopo) versione 0.11 di Tensorflow :

Salva il modello:

import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print sess.run(w4,feed_dict)
#Prints 24 which is sum of (w1+w2)*b1 

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

Ripristina il modello:

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print sess.run(op_to_restore,feed_dict)
#This will print 60 which is calculated 

Questo e alcuni casi d'uso più avanzati sono stati spiegati molto bene qui.

Un breve tutorial completo per salvare e ripristinare i modelli Tensorflow


3
+1 per questo # Accedi alle variabili salvate direttamente stampa (sess.run ('bias: 0')) # Questo stamperà 2, che è il valore del bias che abbiamo salvato. Aiuta a scopi di debug per vedere se il modello è caricato correttamente. le variabili possono essere ottenute con "All_varaibles = tf.get_collection (tf.GraphKeys.GLOBAL_VARIABLES". Inoltre, "sess.run (tf.global_variables_initializer ())" deve essere prima del ripristino.
LGG

1
Sei sicuro di dover eseguire di nuovo global_variables_initializer? Ho ripristinato il mio grafico con global_variable_initialization e mi dà ogni volta un output diverso sugli stessi dati. Quindi ho commentato l'inizializzazione e ho appena ripristinato il grafico, la variabile di input e le operazioni, e ora funziona bene.
Aditya Shinde,

@AdityaShinde Non capisco perché ottengo sempre valori diversi ogni volta. E non ho incluso il passaggio di inizializzazione variabile per il ripristino. Sto usando il mio codice tra l'altro.
Chaine,

@AdityaShinde: non è necessario init op poiché i valori sono già inizializzati dalla funzione di ripristino, quindi è stato rimosso. Tuttavia, non sono sicuro del motivo per cui hai ottenuto un output diverso utilizzando init op.
affondò l'

5
@sankit Quando ripristini i tensori perché aggiungi :0ai nomi?
Sahar Rabinoviz,

177

Nella (e dopo) versione 0.11.0RC1 di TensorFlow, è possibile salvare e ripristinare il modello direttamente chiamando tf.train.export_meta_graphe tf.train.import_meta_graphsecondo https://www.tensorflow.org/programmers_guide/meta_graph .

Salva il modello

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files:my-model.meta

Ripristina il modello

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

4
come caricare le variabili dal modello salvato? Come copiare i valori in qualche altra variabile?
Neel,

9
Non riesco a far funzionare questo codice. Il modello viene salvato ma non riesco a ripristinarlo. Mi sta dando questo errore. <built-in function TF_Run> returned a result with an error set
Saad Qureshi,

2
Quando dopo il ripristino accedo alle variabili come mostrato sopra, funziona. Ma non riesco a ottenere le variabili più direttamente usando tf.get_variable_scope().reuse_variables()seguito da var = tf.get_variable("varname"). Questo mi dà l'errore: "ValueError: la variabile varname non esiste o non è stata creata con tf.get_variable ()." Perché? Questo non dovrebbe essere possibile?
Johann Petrak,

4
Funziona bene solo per le variabili, ma come si può ottenere l'accesso a un segnaposto e alimentare i valori dopo aver ripristinato il grafico?
kbrose,

11
Questo mostra solo come ripristinare le variabili. Come è possibile ripristinare l'intero modello e testarlo su nuovi dati senza ridefinire la rete?
Chaine,

127

Per la versione TensorFlow <0.11.0RC1:

I punti di controllo salvati contengono valori per Variable s nel modello, non per il modello / grafico stesso, il che significa che il grafico dovrebbe essere lo stesso quando si ripristina il punto di controllo.

Ecco un esempio di regressione lineare in cui è presente un ciclo di addestramento che salva i punti di controllo variabili e una sezione di valutazione che ripristinerà le variabili salvate in una corsa precedente e calcolerà le previsioni. Naturalmente, puoi anche ripristinare le variabili e continuare l'allenamento, se lo desideri.

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))

...more setup for optimization and what not...

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if FLAGS.train:
        for i in xrange(FLAGS.training_steps):
            ...training loop...
            if (i + 1) % FLAGS.checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                           global_step=i+1)
    else:
        # Here's where you're restoring the variables w and b.
        # Note that the graph is exactly as it was when the variables were
        # saved in a prior training run.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            ...no checkpoint found...

        # Now you can run the model to get predictions
        batch_x = ...load some data...
        predictions = sess.run(y_hat, feed_dict={x: batch_x})

Ecco i documenti per Variables, che riguardano il salvataggio e il ripristino. Ed ecco i documenti per il Saver.


1
I FLAG sono definiti dall'utente. Ecco un esempio per definirli: github.com/tensorflow/tensorflow/blob/master/tensorflow/…
Ryan Sepassi,

in quale formato batch_xdeve essere? Binario? Matrice numpy?
pepe,

@pepe Numpy arrary dovrebbe andare bene. E il tipo di elemento dovrebbe corrispondere al tipo di segnaposto. [link] tensorflow.org/versions/r0.9/api_docs/python/…
Donny

FLAGS dà errore undefined. Puoi dirmi qual è la definizione di FLAG per questo codice. @RyanSepassi
Muhammad Hannan,

Per rendere più esplicito: versioni recenti di tensorflow non consentono di memorizzare il modello / grafico. [Non è chiaro per me quali aspetti della risposta si applicano al vincolo <0.11. Dato il gran numero di voti, sono stato tentato di credere che questa affermazione generale sia ancora vera per le versioni recenti.]
bluenote10

78

Il mio ambiente: Python 3.6, Tensorflow 1.3.0

Sebbene ci siano state molte soluzioni, la maggior parte di esse si basa su tf.train.Saver. Quando carichiamo un .ckptsalvati da Saver, dobbiamo ridefinire sia la rete tensorflow o utilizzare qualche nome strano e difficile ricordare, ad esempio 'placehold_0:0', 'dense/Adam/Weight:0'. Qui ti consiglio di usare tf.saved_model, un esempio più semplice riportato di seguito, puoi saperne di più dal servizio di un modello TensorFlow :

Salva il modello:

import tensorflow as tf

# define the tensorflow network and do some trains
x = tf.placeholder("float", name="x")
w = tf.Variable(2.0, name="w")
b = tf.Variable(0.0, name="bias")

h = tf.multiply(x, w)
y = tf.add(h, b, name="y")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# save the model
export_path =  './savedmodel'
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

prediction_signature = (
  tf.saved_model.signature_def_utils.build_signature_def(
      inputs={'x_input': tensor_info_x},
      outputs={'y_output': tensor_info_y},
      method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

builder.add_meta_graph_and_variables(
  sess, [tf.saved_model.tag_constants.SERVING],
  signature_def_map={
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          prediction_signature 
  },
  )
builder.save()

Carica il modello:

import tensorflow as tf
sess=tf.Session() 
signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_key = 'x_input'
output_key = 'y_output'

export_path =  './savedmodel'
meta_graph_def = tf.saved_model.loader.load(
           sess,
          [tf.saved_model.tag_constants.SERVING],
          export_path)
signature = meta_graph_def.signature_def

x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name

x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)

y_out = sess.run(y, {x: 3.0})

4
+1 per un ottimo esempio dell'API SavedModel. Tuttavia, vorrei che la tua sezione Salva il modello mostrasse un ciclo di allenamento come la risposta di Ryan Sepassi! Mi rendo conto che questa è una vecchia domanda, ma questa risposta è uno dei pochi (e preziosi) esempi di SavedModel che ho trovato su Google.
Dylan F,

@Tom questa è un'ottima risposta - solo una volta rivolta al nuovo SavedModel. Potresti dare un'occhiata a questa domanda SavedModel? stackoverflow.com/questions/48540744/...
Bluesummers

Ora fai funzionare tutto correttamente con i modelli TF Eager. Nella loro presentazione del 2018, Google ha consigliato a tutti di allontanarsi dal codice grafico TF.
Geoffrey Anderson,

55

Ci sono due parti nel modello, la definizione del modello, salvata da Supervisorcome graph.pbtxtnella directory del modello e i valori numerici dei tensori, salvati in file di checkpoint come model.ckpt-1003418.

La definizione del modello può essere ripristinata utilizzando tf.import_graph_defe i pesi vengono ripristinati utilizzando Saver.

Tuttavia, Saverutilizza un elenco speciale di raccolte che contiene variabili associate al modello Graph e questa raccolta non è inizializzata utilizzando import_graph_def, quindi al momento non è possibile utilizzare le due insieme (è sulla nostra tabella di marcia da correggere). Per ora, è necessario utilizzare l'approccio di Ryan Sepassi: costruire manualmente un grafico con nomi di nodo identici e utilizzare Saverper caricare i pesi in esso.

(In alternativa, è possibile hackerarlo utilizzando import_graph_def, creando variabili manualmente e utilizzando tf.add_to_collection(tf.GraphKeys.VARIABLES, variable)per ogni variabile, quindi utilizzando Saver)


Nell'esempio classify_image.py che utilizza inceptionv3, viene caricato solo graphdef. Significa che ora GraphDef contiene anche la variabile?
gennaio

1
@jrabary Il modello è stato probabilmente congelato .
Eric Platon,

1
Ciao, sono nuovo di Tensorflow e ho problemi a salvare il mio modello. Ti sarei davvero grato se mi potessi aiutare stackoverflow.com/questions/48083474/…
Ruchir Baronia,

39

Puoi anche prendere questo modo più semplice.

Passaggio 1: inizializza tutte le variabili

W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

Similarly, W2, B2, W3, .....

Passaggio 2: salvare la sessione all'interno del modello Savere salvarla

model_saver = tf.train.Saver()

# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")

Passaggio 3: ripristinare il modello

with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.") 
    print('Initialized')

Passaggio 4: controlla la tua variabile

W1 = session.run(W1)
print(W1)

Durante l'esecuzione in diverse istanze di Python, utilizzare

with tf.Session() as sess:
    # Restore latest checkpoint
    saver.restore(sess, tf.train.latest_checkpoint('saved_model/.'))

    # Initalize the variables
    sess.run(tf.global_variables_initializer())

    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()

    # It will give tensor object
    W1 = graph.get_tensor_by_name('W1:0')

    # To get the value (numpy array)
    W1_value = session.run(W1)

Ciao, come posso salvare il modello dopo supponiamo 3000 iterazioni, simili a Caffe. Ho scoperto che tensorflow salva solo gli ultimi modelli nonostante io concateni il numero di iterazione con il modello per differenziarlo tra tutte le iterazioni. Intendo model_3000.ckpt, model_6000.ckpt, --- model_100000.ckpt. Puoi gentilmente spiegare perché non salva tutto piuttosto salva solo le ultime 3 iterazioni.
Khan,


3
Esiste un metodo per ottenere tutti i nomi di variabili / operazioni salvati nel grafico?
Moondra,

21

Nella maggior parte dei casi, il salvataggio e il ripristino da disco utilizzando a tf.train.Saverè l'opzione migliore:

... # build your model
saver = tf.train.Saver()

with tf.Session() as sess:
    ... # train the model
    saver.save(sess, "/tmp/my_great_model")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

Puoi anche salvare / ripristinare la struttura grafica stessa (vedi la documentazione MetaGraph per i dettagli). Per impostazione predefinita, Saversalva la struttura del grafico in un .metafile. Puoi chiamare import_meta_graph()per ripristinarlo. Ripristina la struttura del grafico e restituisce un Saverche è possibile utilizzare per ripristinare lo stato del modello:

saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

Tuttavia, ci sono casi in cui hai bisogno di qualcosa di molto più veloce. Ad esempio, se si implementa l'arresto anticipato, si desidera salvare i punti di controllo ogni volta che il modello migliora durante l'allenamento (come misurato sul set di convalida), quindi se non ci sono progressi per un certo periodo di tempo, si desidera tornare al modello migliore. Se salvi il modello su disco ogni volta che migliora, rallenterà enormemente l'allenamento. Il trucco è salvare gli stati variabili in memoria , quindi ripristinarli in seguito:

... # build your model

# get a handle on the graph nodes we need to save/restore the model
graph = tf.get_default_graph()
gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
init_values = [assign_op.inputs[1] for assign_op in assign_ops]

with tf.Session() as sess:
    ... # train the model

    # when needed, save the model state to memory
    gvars_state = sess.run(gvars)

    # when needed, restore the model state
    feed_dict = {init_value: val
                 for init_value, val in zip(init_values, gvars_state)}
    sess.run(assign_ops, feed_dict=feed_dict)

Una rapida spiegazione: quando si crea una variabile X, TensorFlow crea automaticamente un'operazione di assegnazione X/Assignper impostare il valore iniziale della variabile. Invece di creare segnaposto e operazioni di assegnazione extra (che renderebbero il grafico disordinato), utilizziamo solo queste operazioni di assegnazione esistenti. Il primo input di ciascuna assegnazione op è un riferimento alla variabile che dovrebbe inizializzare e il secondo input ( assign_op.inputs[1]) è il valore iniziale. Pertanto, al fine di impostare qualsiasi valore desiderato (anziché il valore iniziale), è necessario utilizzare a feed_dicte sostituire il valore iniziale. Sì, TensorFlow ti consente di inserire un valore per qualsiasi operazione, non solo per i segnaposto, quindi funziona perfettamente.


Grazie per la risposta. Ho una domanda simile su come convertire un singolo file .ckpt in due file .index e .data (diciamo per i modelli di inizio pre-addestrati disponibili su tf.slim). La mia domanda è qui: stackoverflow.com/questions/47762114/…
Amir

Ciao, sono nuovo di Tensorflow e ho problemi a salvare il mio modello. Ti sarei davvero grato se mi potessi aiutare stackoverflow.com/questions/48083474/…
Ruchir Baronia,

17

Come ha detto Yaroslav, puoi hackerare il ripristino da un graph_def e un checkpoint importando il grafico, creando manualmente variabili e quindi usando un Saver.

L'ho implementato per uso personale, quindi ho pensato di condividere il codice qui.

link: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

(Questo è, ovviamente, un trucco e non vi è alcuna garanzia che i modelli salvati in questo modo rimangano leggibili nelle versioni future di TensorFlow.)


14

Se si tratta di un modello salvato internamente, è sufficiente specificare un restauratore per tutte le variabili come

restorer = tf.train.Saver(tf.all_variables())

e usalo per ripristinare le variabili in una sessione corrente:

restorer.restore(self._sess, model_file)

Per il modello esterno è necessario specificare la mappatura dai nomi delle sue variabili ai nomi delle variabili. È possibile visualizzare i nomi delle variabili del modello utilizzando il comando

python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt

Lo script inspect_checkpoint.py si trova nella cartella './tensorflow/python/tools' del sorgente Tensorflow.

Per specificare la mappatura, è possibile utilizzare il mio Tensorflow-Worklab , che contiene una serie di classi e script per addestrare e riqualificare diversi modelli. Include un esempio di riqualificazione dei modelli ResNet, disponibile qui


all_variables()è ora deprecato
MiniQuark

Ciao, sono nuovo di Tensorflow e ho problemi a salvare il mio modello. Ti sarei davvero grato se mi potessi aiutare stackoverflow.com/questions/48083474/…
Ruchir Baronia,

12

Ecco la mia semplice soluzione per i due casi di base che differiscono sul fatto che si desideri caricare il grafico dal file o crearlo durante il runtime.

Questa risposta vale per Tensorflow 0.12+ (incluso 1.0).

Ricostruzione del grafico nel codice

Salvataggio

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

Caricamento in corso

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # now you can use the graph, continue training or whatever

Caricamento anche del grafico da un file

Quando usi questa tecnica, assicurati che tutti i tuoi strati / variabili abbiano esplicitamente impostato nomi univoci.Altrimenti Tensorflow renderà i nomi univoci e saranno quindi diversi dai nomi memorizzati nel file. Non è un problema nella tecnica precedente, perché i nomi sono "alterati" allo stesso modo sia nel caricamento che nel salvataggio.

Salvataggio

graph = ... # build the graph

for op in [ ... ]:  # operators you want to use after restoring the model
    tf.add_to_collection('ops_to_restore', op)

saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

Caricamento in corso

with ... as sess:  # your session object
    saver = tf.train.import_meta_graph('my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection

-1 Iniziare la risposta eliminando "tutte le altre risposte qui" è un po 'dura. Detto questo, ho effettuato il downgrade per altri motivi: dovresti assolutamente salvare tutte le variabili globali, non solo le variabili addestrabili. Ad esempio, la global_stepvariabile e le medie mobili della normalizzazione batch sono variabili non addestrabili, ma vale sicuramente la pena salvarle entrambe. Inoltre, dovresti distinguere più chiaramente la costruzione del grafico dall'esecuzione della sessione, ad esempio Saver(...).save()creerà nuovi nodi ogni volta che lo eseguirai. Probabilmente non quello che vuoi. E c'è di più ...: /
MiniQuark

@MiniQuark ok, grazie per il tuo feedback, modificherò la risposta in base ai tuoi suggerimenti;)
Martin Pecka

10

Puoi anche dare un'occhiata agli esempi in TensorFlow / skflow , che offre savee restoremetodi che possono aiutarti a gestire facilmente i tuoi modelli. Ha parametri che puoi anche controllare con quale frequenza desideri eseguire il backup del tuo modello.


9

Se usi tf.train.MonitoredTrainingSession come sessione predefinita, non è necessario aggiungere altro codice per salvare / ripristinare le cose. Basta passare un nome dir di checkpoint al costruttore di MonitoredTrainingSession, che utilizzerà gli hook di sessione per gestirli.


utilizzando tf.train.Supervisor gestirà la creazione di tale sessione per te e fornirà una soluzione più completa.
Segna il

1
@Mark tf.train.Supervisor è deprecato
Changming Sun

Hai qualche link a supporto del reclamo che il Supervisore è deprecato? Non ho visto nulla che indichi che questo è il caso.
Segna l'


Grazie per l'URL: ho verificato con l'origine originale delle informazioni e mi è stato detto che probabilmente rimarrà in circolazione fino alla fine della serie TF 1.x, ma nessuna garanzia dopo.
Segna il

8

Tutte le risposte qui sono fantastiche, ma voglio aggiungere due cose.

In primo luogo, per elaborare la risposta di @ user7505159, "./" può essere importante aggiungere all'inizio del nome del file che si sta ripristinando.

Ad esempio, puoi salvare un grafico senza "./" nel nome del file in questo modo:

# Some graph defined up here with specific names

saver = tf.train.Saver()
save_file = 'model.ckpt'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)

Ma per ripristinare il grafico, potrebbe essere necessario anteporre un "./" al nome_file:

# Same graph defined up here

saver = tf.train.Saver()
save_file = './' + 'model.ckpt' # String addition used for emphasis

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, save_file)

Non sarà sempre necessario "./", ma può causare problemi a seconda dell'ambiente e della versione di TensorFlow.

Vuole anche menzionare che sess.run(tf.global_variables_initializer())può essere importante prima di ripristinare la sessione.

Se ricevi un errore relativo alle variabili non inizializzate quando tenti di ripristinare una sessione salvata, assicurati di includerlo sess.run(tf.global_variables_initializer())prima della saver.restore(sess, save_file)riga. Ti può far venire il mal di testa.


7

Come descritto nel numero 6255 :

use '**./**model_name.ckpt'
saver.restore(sess,'./my_model_final.ckpt')

invece di

saver.restore('my_model_final.ckpt')

7

Secondo la nuova versione di Tensorflow, tf.train.Checkpointè il modo preferibile per salvare e ripristinare un modello:

Checkpoint.savee Checkpoint.restorescrivere e leggere checkpoint basati su oggetti, in contrasto con tf.train.Saver che scrive e legge checkpoint basati su variabile.name. Il checkpoint basato sugli oggetti salva un grafico delle dipendenze tra gli oggetti Python (Layer, Optimizer, Variabili, ecc.) Con bordi denominati e questo grafico viene utilizzato per abbinare le variabili durante il ripristino di un checkpoint. Può essere più robusto per le modifiche nel programma Python e aiuta a supportare il ripristino su creazione per le variabili quando si esegue con impazienza. Preferisci tf.train.Checkpointoltre tf.train.Saverper il nuovo codice .

Ecco un esempio:

import tensorflow as tf
import os

tf.enable_eager_execution()

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)

Maggiori informazioni ed esempio qui.


7

Per tensorflow 2.0 , è semplice come

# Save the model
model.save('path_to_my_model.h5')

Ripristinare:

new_model = tensorflow.keras.models.load_model('path_to_my_model.h5')

Che dire di tutte le operazioni personalizzate tf e le variabili che non fanno parte dell'oggetto modello? Verranno salvati in qualche modo quando chiami save () sul modello? Ho varie espressioni di perdita personalizzate e probabilità di tensorflow che vengono utilizzate nella rete di inferenza e generazione ma non fanno parte del mio modello. Il mio oggetto modello keras contiene solo i livelli denso e conv. In TF 1 ho appena chiamato il metodo save e potrei essere sicuro che tutte le operazioni e i tensori utilizzati nel mio grafico verrebbero salvati. In TF2 non vedo come verranno salvate le operazioni che in qualche modo non vengono aggiunte al modello di keras.
Kristof,

Ci sono altre informazioni sul ripristino dei modelli in TF 2.0? Non posso ripristinare i pesi da file di checkpoint generati tramite l'API C, vedi: stackoverflow.com/questions/57944786/...
jregalad


5

tf.keras Salvataggio del modello con TF2.0

Vedo grandi risposte per il salvataggio di modelli con TF1.x. Voglio fornire un paio di ulteriori indicazioni per il salvataggiotensorflow.keras modelli, il che è un po 'complicato in quanto ci sono molti modi per salvare un modello.

Qui sto fornendo un esempio di salvataggio di un tensorflow.kerasmodello nella model_pathcartella nella directory corrente. Funziona bene con il più recente tensorflow (TF2.0). Aggiornerò questa descrizione in caso di cambiamenti nel prossimo futuro.

Salvataggio e caricamento dell'intero modello

import tensorflow as tf
from tensorflow import keras
mnist = tf.keras.datasets.mnist

#import data
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# create a model
def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
# compile the model
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
  return model

# Create a basic model instance
model=create_model()

model.fit(x_train, y_train, epochs=1)
loss, acc = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

# Save entire model to a HDF5 file
model.save('./model_path/my_model.h5')

# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('./model_path/my_model.h5')
loss, acc = new_model.evaluate(x_test, y_test)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

Salvataggio e caricamento solo dei pesi modello

Se sei interessato a salvare solo i pesi del modello e quindi caricare i pesi per ripristinare il modello, allora

model.fit(x_train, y_train, epochs=5)
loss, acc = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss,acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

Salvataggio e ripristino mediante callback del checkpoint di keras

# include the epoch in the file name. (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1, save_weights_only=True,
    # Save weights, every 5-epochs.
    period=5)

model = create_model()
model.save_weights(checkpoint_path.format(epoch=0))
model.fit(train_images, train_labels,
          epochs = 50, callbacks = [cp_callback],
          validation_data = (test_images,test_labels),
          verbose=0)

latest = tf.train.latest_checkpoint(checkpoint_dir)

new_model = create_model()
new_model.load_weights(latest)
loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

modello di salvataggio con metriche personalizzate

import tensorflow as tf
from tensorflow import keras
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Custom Loss1 (for example) 
@tf.function() 
def customLoss1(yTrue,yPred):
  return tf.reduce_mean(yTrue-yPred) 

# Custom Loss2 (for example) 
@tf.function() 
def customLoss2(yTrue, yPred):
  return tf.reduce_mean(tf.square(tf.subtract(yTrue,yPred))) 

def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),  
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy', customLoss1, customLoss2])
  return model

# Create a basic model instance
model=create_model()

# Fit and evaluate model 
model.fit(x_train, y_train, epochs=1)
loss, acc,loss1, loss2 = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

model.save("./model.h5")

new_model=tf.keras.models.load_model("./model.h5",custom_objects={'customLoss1':customLoss1,'customLoss2':customLoss2})

Salvataggio del modello di keras con operazioni personalizzate

Quando abbiamo operazioni personalizzate come nel caso seguente ( tf.tile), dobbiamo creare una funzione e concludere con un livello Lambda. Altrimenti, il modello non può essere salvato.

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras import Model

def my_fun(a):
  out = tf.tile(a, (1, tf.shape(a)[0]))
  return out

a = Input(shape=(10,))
#out = tf.tile(a, (1, tf.shape(a)[0]))
out = Lambda(lambda x : my_fun(x))(a)
model = Model(a, out)

x = np.zeros((50,10), dtype=np.float32)
print(model(x).numpy())

model.save('my_model.h5')

#load the model
new_model=tf.keras.models.load_model("my_model.h5")

Penso di aver coperto alcuni dei molti modi per salvare il modello tf.keras. Tuttavia, ci sono molti altri modi. Commenta di seguito se vedi che il tuo caso d'uso non è coperto sopra. Grazie!


3

Utilizzare tf.train.Saver per salvare un modello, promemoria, è necessario specificare var_list, se si desidera ridurre le dimensioni del modello. Val_list può essere tf.trainable_variables o tf.global_variables.


3

È possibile salvare le variabili nella rete utilizzando

saver = tf.train.Saver() 
saver.save(sess, 'path of save/fileName.ckpt')

Per ripristinare la rete per il riutilizzo in un secondo momento o in un altro script, utilizzare:

saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint('path of save/')
sess.run(....) 

Punti importanti:

  1. sess deve essere lo stesso tra la prima e la successiva (struttura coerente).
  2. saver.restore necessita del percorso della cartella dei file salvati, non di un singolo percorso del file.

2

Ovunque tu voglia salvare il modello,

self.saver = tf.train.Saver()
with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            ...
            self.saver.save(sess, filename)

Assicurati che tutti tf.Variableabbiano dei nomi, perché potresti volerli ripristinare in seguito usando i loro nomi. E dove vuoi prevedere,

saver = tf.train.import_meta_graph(filename)
name = 'name given when you saved the file' 
with tf.Session() as sess:
      saver.restore(sess, name)
      print(sess.run('W1:0')) #example to retrieve by variable name

Assicurarsi che il risparmiatore venga eseguito nella sessione corrispondente. Ricordare che, se si utilizza il tf.train.latest_checkpoint('./'), verrà utilizzato solo l'ultimo punto di controllo.


2

Sono sulla versione:

tensorflow (1.13.1)
tensorflow-gpu (1.13.1)

Il modo semplice è

Salva:

model.save("model.h5")

Ristabilire:

model = tf.keras.models.load_model("model.h5")

2

Per tensorflow-2.0

è molto semplice.

import tensorflow as tf

SALVA

model.save("model_name")

RISTABILIRE

model = tf.keras.models.load_model('model_name')

1

Seguendo la risposta di @Vishnuvardhan Janapati, ecco un altro modo per salvare e ricaricare il modello con layer / metrica / perdita personalizzati in TensorFlow 2.0.0

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils.generic_utils import get_custom_objects

# custom loss (for example)  
def custom_loss(y_true,y_pred):
  return tf.reduce_mean(y_true - y_pred)
get_custom_objects().update({'custom_loss': custom_loss}) 

# custom loss (for example) 
class CustomLayer(Layer):
  def __init__(self, ...):
      ...
  # define custom layer and all necessary custom operations inside custom layer

get_custom_objects().update({'CustomLayer': CustomLayer})  

In questo modo, una volta eseguiti tali codici e salvato il modello con tf.keras.models.save_modelo model.saveo ModelCheckpointcallback, è possibile ricaricare il modello senza la necessità di oggetti personalizzati precisi, semplici come

new_model = tf.keras.models.load_model("./model.h5"})

0

Nella nuova versione di tensorflow 2.0, il processo di salvataggio / caricamento di un modello è molto più semplice. A causa dell'implementazione dell'API di Keras, un'API di alto livello per TensorFlow.

Per salvare un modello: consultare la documentazione di riferimento: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/save_model

tf.keras.models.save_model(model_name, filepath, save_format)

Per caricare un modello:

https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/load_model

model = tf.keras.models.load_model(filepath)

0

Ecco un semplice esempio usando il formato SavedModel di Tensorflow 2.0 (che è il formato consigliato, secondo i documenti ) per un semplice classificatore di set di dati MNIST, usando l'API funzionale di Keras senza troppa fantasia:

# Imports
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt

# Load data
mnist = tf.keras.datasets.mnist # 28 x 28
(x_train,y_train), (x_test, y_test) = mnist.load_data()

# Normalize pixels [0,255] -> [0,1]
x_train = tf.keras.utils.normalize(x_train,axis=1)
x_test = tf.keras.utils.normalize(x_test,axis=1)

# Create model
input = Input(shape=(28,28), dtype='float64', name='graph_input')
x = Flatten()(input)
x = Dense(128, activation='relu')(x)
x = Dense(128, activation='relu')(x)
output = Dense(10, activation='softmax', name='graph_output', dtype='float64')(x)
model = Model(inputs=input, outputs=output)

model.compile(optimizer='adam',
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])

# Train
model.fit(x_train, y_train, epochs=3)

# Save model in SavedModel format (Tensorflow 2.0)
export_path = 'model'
tf.saved_model.save(model, export_path)

# ... possibly another python program 

# Reload model
loaded_model = tf.keras.models.load_model(export_path) 

# Get image sample for testing
index = 0
img = x_test[index] # I normalized the image on a previous step

# Predict using the signature definition (Tensorflow 2.0)
predict = loaded_model.signatures["serving_default"]
prediction = predict(tf.constant(img))

# Show results
print(np.argmax(prediction['graph_output']))  # prints the class number
plt.imshow(x_test[index], cmap=plt.cm.binary)  # prints the image

Che cosa è serving_default?

È il nome della firma def del tag selezionato (in questo caso è servestato selezionato il tag predefinito ). Inoltre, qui spiega come trovare i tag e le firme di un modello usando saved_model_cli.

Avvertenze

Questo è solo un esempio di base se vuoi solo metterlo in funzione, ma non è affatto una risposta completa - forse potrei aggiornarlo in futuro. Volevo solo fare un semplice esempio usando ilSavedModel in TF 2.0 perché non ne ho visto uno, neanche così semplice, da nessuna parte.

La risposta di @ Tom è un esempio di SavedModel, ma non funzionerà su Tensorflow 2.0, perché purtroppo ci sono alcune modifiche.

@ Vishnuvardhan La risposta di Janapati dice TF 2.0, ma non è per il formato SavedModel.

Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.