Come tf.app.run()
funziona la traduzione demo di Tensorflow?
In tensorflow/models/rnn/translate/translate.py
, c'è una chiamata a tf.app.run()
. Come viene gestito?
if __name__ == "__main__":
tf.app.run()
Come tf.app.run()
funziona la traduzione demo di Tensorflow?
In tensorflow/models/rnn/translate/translate.py
, c'è una chiamata a tf.app.run()
. Come viene gestito?
if __name__ == "__main__":
tf.app.run()
Risposte:
if __name__ == "__main__":
significa che il file corrente viene eseguito sotto una shell anziché importato come modulo.
tf.app.run()
Come puoi vedere attraverso il file app.py
def run(main=None, argv=None):
"""Runs the program with an optional 'main' function and 'argv' list."""
f = flags.FLAGS
# Extract the args from the optional `argv` list.
args = argv[1:] if argv else None
# Parse the known flags from that list, or from the command
# line otherwise.
# pylint: disable=protected-access
flags_passthrough = f._parse_flags(args=args)
# pylint: enable=protected-access
main = main or sys.modules['__main__'].main
# Call the main function, passing through any arguments
# to the final program.
sys.exit(main(sys.argv[:1] + flags_passthrough))
Interrompiamo riga per riga:
flags_passthrough = f._parse_flags(args=args)
Ciò garantisce che l'argomento che passi attraverso la riga di comando sia valido, ad es.
python my_model.py --data_dir='...' --max_iteration=10000
In realtà, questa funzione è implementata in base al argparse
modulo standard di Python .
main = main or sys.modules['__main__'].main
Il primo main
a destra di =
è il primo argomento della funzione corrente run(main=None, argv=None)
. Mentre sys.modules['__main__']
significa file corrente in esecuzione (ad esempio my_model.py
).
Quindi ci sono due casi:
Non hai una main
funzione in my_model.py
Quindi devi chiamaretf.app.run(my_main_running_function)
hai una main
funzione in my_model.py
. (Questo è principalmente il caso.)
Ultima linea:
sys.exit(main(sys.argv[:1] + flags_passthrough))
assicura che la funzione main(argv)
o my_main_running_function(argv)
venga chiamata correttamente con argomenti analizzati.
abseil
cui TF deve aver assorbito abseil.io/docs/python/guides/flags
È solo un wrapper molto veloce che gestisce l'analisi delle bandiere e quindi le invia al tuo principale. Vedere il codice .
main = main or sys.modules['__main__'].main
e cosa sys.exit(main(sys.argv[:1] + flags_passthrough))
significano?
main()
?
Non c'è niente di speciale in tf.app
. Questo è solo uno script di entry point generico , che
Esegue il programma con una funzione 'principale' opzionale e un elenco 'argv'.
Non ha nulla a che fare con le reti neurali e chiama semplicemente la funzione principale, passando attraverso qualsiasi argomento.
In termini semplici, il compito di tf.app.run()
è innanzitutto impostare i flag globali per un utilizzo successivo come:
from tensorflow.python.platform import flags
f = flags.FLAGS
e quindi esegui la tua funzione principale personalizzata con una serie di argomenti.
Ad esempio, nella base di codici NMT TensorFlow , il primo punto di ingresso per l'esecuzione del programma per addestramento / inferenza inizia a questo punto (vedere il codice seguente)
if __name__ == "__main__":
nmt_parser = argparse.ArgumentParser()
add_arguments(nmt_parser)
FLAGS, unparsed = nmt_parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
Dopo aver analizzato gli argomenti usando argparse
, tf.app.run()
esegui con te la funzione "main" che è definita come:
def main(unused_argv):
default_hparams = create_hparams(FLAGS)
train_fn = train.train
inference_fn = inference.inference
run_main(FLAGS, default_hparams, train_fn, inference_fn)
Quindi, dopo aver impostato i flag per l'uso globale, tf.app.run()
esegue semplicemente quella main
funzione che gli si passa argv
come parametri.
PS: Come dice la risposta di Salvador Dalì , è solo una buona pratica di ingegneria del software, immagino, anche se non sono sicuro che TensorFlow esegua una corsa ottimizzata della main
funzione rispetto a quella eseguita con il normale CPython.
Il codice di Google dipende molto dal fatto che i flag globali accedono alle librerie / binari / script python e quindi tf.app.run () analizza quei flag per creare uno stato globale nella variabile FLAG (o qualcosa di simile) e quindi chiama python main ( ) come dovrebbe.
Se non avessero ricevuto questa chiamata a tf.app.run (), gli utenti potrebbero dimenticare di eseguire l'analisi dei FLAG, portando a queste librerie / binari / script che non hanno accesso ai FLAG di cui hanno bisogno.
2.0 Risposta compatibile : se si desidera utilizzare tf.app.run()
in Tensorflow 2.0
, dovremmo usare il comando,
tf.compat.v1.app.run()
oppure puoi usare tf_upgrade_v2
per convertire il 1.x
codice in 2.0
.
tf.flags.DEFINE_integer('batch_size', 128, 'Number of images to process in a batch.')
e quindi, se li usitf.app.run()
, imposteranno le cose in modo da poter accedere a livello globale ai valori passati dei flag che hai definito, cometf.flags.FLAGS.batch_size
da qualsiasi punto in cui ti serva nel tuo codice.