Come elencare tutte le operazioni utilizzate in Tensorflow SavedModel?


10

Se salvo il mio modello utilizzando la tensorflow.saved_model.savefunzione nel formato SavedModel, come posso recuperare quali Tensorflow Ops vengono utilizzati in questo modello in seguito. Poiché il modello può essere ripristinato, queste operazioni sono memorizzate nel grafico, la mia ipotesi è nel saved_model.pbfile. Se carico questo protobuf (quindi non l'intero modello), la parte della libreria del protobuf li elenca, ma questo non è documentato e etichettato come funzionalità sperimentale per ora. I modelli creati in Tensorflow 1.x non avranno questa parte.

Quindi qual è un modo rapido e affidabile per recuperare un elenco di operazioni usate (come MatchingFileso WriteFile) da un modello in formato SavedModel?

In questo momento posso congelare l'intera cosa, come tensorflowjs-converterfa. Poiché controllano anche le operazioni supportate. Questo attualmente non funziona quando nel modello è presente un LSTM, vedere qui . C'è un modo migliore per farlo, dato che gli Ops sono sicuramente lì dentro?

Un modello di esempio:

class FileReader(tf.Module):

@tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
def read_disk(self, file_name):
    input_scalar = tf.reshape(file_name, [])
    output = tf.io.read_file(input_scalar)
    return tf.stack([output], name='content')

file_reader = FileReader()

tf.saved_model.save(file_reader, 'file_reader')

Attesi in uscita tutti gli Op, contenenti in questo caso almeno:

  • ReadFilecome descritto qui
  • ...

1
È difficile dire esattamente cosa vuoi, cos'è saved_model.pb, è tf.GraphDefun SavedModelmessaggio o un messaggio protobuf? Se hai tf.GraphDefchiamato gd, puoi ottenere l'elenco delle operazioni usate con sorted(set(n.op for n in gd.node)). Se hai un modello caricato, puoi farlo sorted(set(op.type for op in tf.get_default_graph().get_operations())). Se è un SavedModel, puoi ottenerlo tf.GraphDefda (ad esempio saved_model.meta_graphs[0].graph_def).
jdehesa,

Voglio recuperare le operazioni da un SavedModel memorizzato. Quindi, in effetti, l'ultima opzione che stai descrivendo. Qual è la saved_modelvariabile nel tuo ultimo esempio? Il risultato tf.saved_model.load('/path/to/model')o il caricamento del protobuf del file saved_model.pb.
campionatori

Risposte:


1

Se saved_model.pbè un SavedModelmessaggio protobuf, allora ottieni le operazioni direttamente da lì. Diciamo che creiamo un modello come segue:

import tensorflow as tf

class FileReader(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(name='filename', shape=[None], dtype=tf.string)])
    def read_disk(self, file_name):
        input_scalar = tf.reshape(file_name, [])
        output = tf.io.read_file(input_scalar)
        return tf.stack([output], name='content')

file_reader = FileReader()
tf.saved_model.save(file_reader, 'tmp')

Ora possiamo trovare le operazioni utilizzate da quel modello in questo modo:

from tensorflow.core.protobuf.saved_model_pb2 import SavedModel

saved_model = SavedModel()
with open('tmp/saved_model.pb', 'rb') as f:
    saved_model.ParseFromString(f.read())
model_op_names = set()
# Iterate over every metagraph in case there is more than one
for meta_graph in saved_model.meta_graphs:
    # Add operations in the graph definition
    model_op_names.update(node.op for node in meta_graph.graph_def.node)
    # Go through the functions in the graph definition
    for func in meta_graph.graph_def.library.function:
        # Add operations in each function
        model_op_names.update(node.op for node in func.node_def)
# Convert to list, sorted if you want
model_op_names = sorted(model_op_names)
print(*model_op_names, sep='\n')
# Const
# Identity
# MergeV2Checkpoints
# NoOp
# Pack
# PartitionedCall
# Placeholder
# ReadFile
# Reshape
# RestoreV2
# SaveV2
# ShardedFilename
# StatefulPartitionedCall
# StringJoin

Ho provato qualcosa del genere, ma sfortunatamente questo non è quello che mi aspetto che faccia: supponiamo di avere un modello che fa questo: input_scalar = tf.reshape(file_name, []) output = tf.io.read_file(input_scalar) return tf.stack([output], name='content')Quindi l'op ReadFile come elencato qui è lì, ma non viene stampato.
campionatori

1
@sampers Ho modificato la risposta con un esempio come tu suggerisci. Ottengo l' ReadFileoperazione nell'output. È possibile che, nel tuo caso attuale, tale operazione non sia tra l'input e l'output del modello salvato? In tal caso, penso che potrebbe essere potata.
jdehesa,

In effetti con il modello dato funziona. Sfortunatamente per un modulo realizzato in tf2, non lo è. Se creo un modulo tf.Mule con 1 funzione con un'annotazione di file_nameargomento @tf.function, contenente le chiamate che ho elencato nel mio commento precedente, viene visualizzato il seguente elenco:Const, NoOp, PartitionedCall, Placeholder, StatefulPartitionedCall
campionatori

aggiunto un modello alla mia domanda
campionatori il

@sampers Ho aggiornato la mia risposta. Prima utilizzavo TF 1.x, non avevo familiarità con le modifiche agli oggetti di definizione del grafico in TF 2.x, penso che la risposta ora copra tutto ciò che riguarda il modello salvato. Penso che le operazioni corrispondenti alla funzione Python che hai scritto siano saved_model.meta_graphs[0].graph_def.library.function[0](la node_defraccolta all'interno di quell'oggetto funzione).
jdehesa,
Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.