Come funziona tf.app.run ()?


149

Come tf.app.run()funziona la traduzione demo di Tensorflow?

In tensorflow/models/rnn/translate/translate.py, c'è una chiamata a tf.app.run(). Come viene gestito?

if __name__ == "__main__":
    tf.app.run() 

Risposte:


135
if __name__ == "__main__":

significa che il file corrente viene eseguito sotto una shell anziché importato come modulo.

tf.app.run()

Come puoi vedere attraverso il file app.py

def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS

  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None

  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access

  main = main or sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  sys.exit(main(sys.argv[:1] + flags_passthrough))

Interrompiamo riga per riga:

flags_passthrough = f._parse_flags(args=args)

Ciò garantisce che l'argomento che passi attraverso la riga di comando sia valido, ad es. python my_model.py --data_dir='...' --max_iteration=10000In realtà, questa funzione è implementata in base al argparsemodulo standard di Python .

main = main or sys.modules['__main__'].main

Il primo maina destra di =è il primo argomento della funzione corrente run(main=None, argv=None) . Mentre sys.modules['__main__']significa file corrente in esecuzione (ad esempio my_model.py).

Quindi ci sono due casi:

  1. Non hai una mainfunzione in my_model.pyQuindi devi chiamaretf.app.run(my_main_running_function)

  2. hai una mainfunzione in my_model.py. (Questo è principalmente il caso.)

Ultima linea:

sys.exit(main(sys.argv[:1] + flags_passthrough))

assicura che la funzione main(argv)o my_main_running_function(argv)venga chiamata correttamente con argomenti analizzati.


67
Un pezzo mancante del puzzle per gli utenti principianti di Tensorflow: Tensorflow ha un meccanismo di gestione delle bandiere a riga di comando incorporato. Puoi definire i tuoi flag come tf.flags.DEFINE_integer('batch_size', 128, 'Number of images to process in a batch.')e quindi, se li usi tf.app.run(), imposteranno le cose in modo da poter accedere a livello globale ai valori passati dei flag che hai definito, come tf.flags.FLAGS.batch_sizeda qualsiasi punto in cui ti serva nel tuo codice.
Isarandi,

1
Questa è la risposta migliore delle (attuali) tre secondo me. Spiega "Come funziona tf.app.run ()", mentre le altre due risposte dicono semplicemente cosa fa.
Thomas Fauskanger,

Sembra che le bandiere siano gestite da abseilcui TF deve aver assorbito abseil.io/docs/python/guides/flags
CpILL

75

È solo un wrapper molto veloce che gestisce l'analisi delle bandiere e quindi le invia al tuo principale. Vedere il codice .


12
cosa significa "gestisce l'analisi delle bandiere"? Forse potresti aggiungere un link per informare i principianti cosa significa?
Pinocchio,

4
Analizza gli argomenti della riga di comando forniti al programma utilizzando il pacchetto flags. (che utilizza la libreria standard "argparse" sotto le copertine, con alcuni wrapper). È collegato dal codice a cui mi sono collegato nella mia risposta.
DGA

1
In app.py, cosa significano main = main or sys.modules['__main__'].maine cosa sys.exit(main(sys.argv[:1] + flags_passthrough))significano?
hAcKnRoCk,

3
questo mi sembra strano, perché avvolgere la funzione principale in tutto ciò se si può semplicemente chiamarla direttamente main()?
Charlie Parker,

2
hAcKnRoCk: se non ci sono file principali nel file, utilizza invece tutto ciò che è in sys.modules [' main '] .main. Sys.exit significa eseguire il comando principale così trovato usando args e tutti i flag passati, e uscire con il valore di ritorno di main. @CharlieParker - per la compatibilità con le librerie di app python esistenti di Google come gflags e google-apputils. Vedi, ad esempio, github.com/google/google-apputils
dga

8

Non c'è niente di speciale in tf.app. Questo è solo uno script di entry point generico , che

Esegue il programma con una funzione 'principale' opzionale e un elenco 'argv'.

Non ha nulla a che fare con le reti neurali e chiama semplicemente la funzione principale, passando attraverso qualsiasi argomento.


5

In termini semplici, il compito di tf.app.run()è innanzitutto impostare i flag globali per un utilizzo successivo come:

from tensorflow.python.platform import flags
f = flags.FLAGS

e quindi esegui la tua funzione principale personalizzata con una serie di argomenti.

Ad esempio, nella base di codici NMT TensorFlow , il primo punto di ingresso per l'esecuzione del programma per addestramento / inferenza inizia a questo punto (vedere il codice seguente)

if __name__ == "__main__":
  nmt_parser = argparse.ArgumentParser()
  add_arguments(nmt_parser)
  FLAGS, unparsed = nmt_parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

Dopo aver analizzato gli argomenti usando argparse, tf.app.run()esegui con te la funzione "main" che è definita come:

def main(unused_argv):
  default_hparams = create_hparams(FLAGS)
  train_fn = train.train
  inference_fn = inference.inference
  run_main(FLAGS, default_hparams, train_fn, inference_fn)

Quindi, dopo aver impostato i flag per l'uso globale, tf.app.run()esegue semplicemente quella mainfunzione che gli si passa argvcome parametri.

PS: Come dice la risposta di Salvador Dalì , è solo una buona pratica di ingegneria del software, immagino, anche se non sono sicuro che TensorFlow esegua una corsa ottimizzata della mainfunzione rispetto a quella eseguita con il normale CPython.


2

Il codice di Google dipende molto dal fatto che i flag globali accedono alle librerie / binari / script python e quindi tf.app.run () analizza quei flag per creare uno stato globale nella variabile FLAG (o qualcosa di simile) e quindi chiama python main ( ) come dovrebbe.

Se non avessero ricevuto questa chiamata a tf.app.run (), gli utenti potrebbero dimenticare di eseguire l'analisi dei FLAG, portando a queste librerie / binari / script che non hanno accesso ai FLAG di cui hanno bisogno.


1

2.0 Risposta compatibile : se si desidera utilizzare tf.app.run()in Tensorflow 2.0, dovremmo usare il comando,

tf.compat.v1.app.run()oppure puoi usare tf_upgrade_v2per convertire il 1.xcodice in 2.0.

Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.