Tensorflow non può ottenere `image.shape` dal metodo in` dataset.map (mapFn) `


10

Sto cercando di fare l' tensorflowequivalente di torch.transforms.Resize(TRAIN_IMAGE_SIZE), a cui ridimensiona la dimensione dell'immagine più piccolaTRAIN_IMAGE_SIZE . Qualcosa come questo

def transforms(filename):
  parts = tf.strings.split(filename, '/')
  label = parts[-2]

  image = tf.io.read_file(filename)
  image = tf.image.decode_jpeg(image)
  image = tf.image.convert_image_dtype(image, tf.float32)

  # this doesn't work with Dataset.map() because image.shape=(None,None,3) from Dataset.map()
  image = largest_sq_crop(image) 

  image = tf.image.resize(image, (256,256))
  return image, label

list_ds = tf.data.Dataset.list_files('{}/*/*'.format(DATASET_PATH))
images_ds = list_ds.map(transforms).batch(4)

La semplice risposta è qui: Tensorflow: ritaglia la più grande regione quadrata centrale dell'immagine

Ma quando uso il metodo con tf.data.Dataset.map(transforms), ottengo shape=(None,None,3)dall'interno largest_sq_crop(image). Il metodo funziona bene quando lo chiamo normalmente.


1
Credo che il problema abbia a che fare con il fatto che EagerTensorsnon sono disponibili all'interno, Dataset.map()quindi la forma è sconosciuta. c'è una soluzione?
michael

Puoi includere la definizione di largest_sq_crop?
Jakub

Risposte:


1

Ho trovato la risposta Ha a che fare con il fatto che il mio metodo di ridimensionamento ha funzionato bene con un'esecuzione entusiasta, ad es. tf.executing_eagerly()==TrueMa fallito quando usato all'interno dataset.map(). A quanto pare, in questo ambiente di esecuzione, tf.executing_eagerly()==False.

Il mio errore era nel modo in cui stavo disimballando la forma dell'immagine per ottenere le dimensioni per il ridimensionamento. L'esecuzione del grafico Tensorflow non sembra supportare l'accesso atensor.shape tupla.

  # wrong
  b,h,w,c = img.shape
  print("ERR> ", h,w,c)
  # ERR>  None None 3

  # also wrong
  b = img.shape[0]
  h = img.shape[1]
  w = img.shape[2]
  c = img.shape[3]
  print("ERR> ", h,w,c)
  # ERR>  None None 3

  # but this works!!!
  shape = tf.shape(img)
  b = shape[0]
  h = shape[1]
  w = shape[2]
  c = shape[3]
  img = tf.reshape( img, (-1,h,w,c))
  print("OK> ", h,w,c)
  # OK>  Tensor("strided_slice_2:0", shape=(), dtype=int32) Tensor("strided_slice_3:0", shape=(), dtype=int32) Tensor("strided_slice_4:0", shape=(), dtype=int32)

Stavo usando le dimensioni della forma a valle nella mia dataset.map()funzione e ha generato la seguente eccezione perché stava ottenendo Noneinvece di un valore.

TypeError: Failed to convert object of type <class 'tuple'> to Tensor. Contents: (-1, None, None, 3). Consider casting elements to a supported type.

Quando sono passato a decomprimere manualmente la forma da tf.shape(), tutto ha funzionato bene.

Utilizzando il nostro sito, riconosci di aver letto e compreso le nostre Informativa sui cookie e Informativa sulla privacy.
Licensed under cc by-sa 3.0 with attribution required.