Problema con l'esecuzione di Object_detection_tutorial TypeError: load () mancante 2 argomenti posizionali richiesti


11

Sono abbastanza nuovo su Tensorflow e sto cercando di eseguire object_detection_tutorial. Ricevo TypeErrror e non so come risolverlo.

Questa è la funzione load_model che manca di 2 argomenti:

tags: set di tag stringa per identificare il MetaGraphDef richiesto. Questi dovrebbero corrispondere ai tag utilizzati quando si salvano le variabili utilizzando l'API save () SavedModel.

export_dir: directory in cui si trovano il buffer del protocollo SavedModel e le variabili da caricare.

def load_model(model_name):
  base_url = 'http://download.tensorflow.org/models/object_detection/'
  model_file = model_name + '.tar.gz'
  model_dir = tf.keras.utils.get_file(
    fname=model_name, 
    origin=base_url + model_file,
    untar=True)

  model_dir = pathlib.Path(model_dir)/"saved_model"

  model = tf.saved_model.load(str(model_dir))
  model = model.signatures['serving_default']

  return model
WARNING:tensorflow:From <ipython-input-9-f8a3c92a04a4>:11: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-12-e10c73a22cc9> in <module>
      1 model_name = 'ssd_mobilenet_v1_coco_2017_11_17'
----> 2 detection_model = load_model(model_name)

<ipython-input-9-f8a3c92a04a4> in load_model(model_name)
      9   model_dir = pathlib.Path(model_dir)/"saved_model"
     10 
---> 11   model = tf.saved_model.load(str(model_dir))
     12   model = model.signatures['serving_default']
     13 

~/.local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    322               'in a future version' if date is None else ('after %s' % date),
    323               instructions)
--> 324       return func(*args, **kwargs)
    325     return tf_decorator.make_decorator(
    326         func, new_func, 'deprecated',

TypeError: load() missing 2 required positional arguments: 'tags' and 'export_dir'

Potete aiutarmi a risolvere questo problema ed eseguire il mio primo rilevatore di oggetti: D?

Risposte:


14

Ho avuto lo stesso problema e sto cercando di risolverlo da 1 settimana. Immagino che la soluzione dovrebbe essere questa;

model = tf.compat.v2.saved_model.load(str(model_dir), None)

Maggiori dettagli sarebbero (dal sito ufficiale );

Carica un SavedModel da export_dir.

tf.saved_model.load(
    export_dir,
    tags=None
)

alias:

tf.compat.v1.saved_model.load_v2

tf.compat.v2.saved_model.load

1
Ho usato la tua soluzione e ho avuto un altro errore. Ho aggiornato tutto quello che potevo e funziona! Ho anche avuto un errore con pathlib non installato.
Dominik,

@Dominik puoi essere più specifico? forse posso aiutare perché questa avventura tensorflow mi ha portato a risolvere molti problemi: D
Onur Baskin

4
@OnurBaskin In seguito si verifica un errore: l'argomento TypeError: int () deve essere una stringa, un oggetto simile a un byte o un numero, non "Tensore"
kaitsu

@Dominik Presumo sia la tua versione di Tensorflow. Dovrebbe essere la versione 2.0 (stabile). Ecco il link per la domanda che ho posto, forse stai riscontrando l'errore esatto. Inoltre, cerca tutte le vecchie importazioni che richiedono "compat.v1". in seguito dovresti avere molti più errori ma è così che migra un vecchio codice.
Onur Baskin,

@OnurBaskin Sono abbastanza confuso. Pensavo che l'API di rilevamento oggetti fosse compatibile solo con le versioni TensorFlow 1.
Biiiiiird,

0

Immaginavo che fosse un problema di filiale e l'utilizzo del ramo tf_2_1_reference ha fatto il trucco per me:

igian@iGians-MBP models % git checkout tf_2_1_reference
M   research/object_detection/object_detection_tutorial.ipynb
Branch 'tf_2_1_reference' set up to track remote branch 'tf_2_1_reference' from 'origin'.
Switched to a new branch 'tf_2_1_reference'
igians@iGians-MBP models % jupyter notebook

Quindi esegui ogni cella di Giove del tutorial come un buon novizio!

Questo è il ramo che ho usato: https://github.com/tensorflow/models/tree/tf_2_1_reference

Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.