def show_batch () не показывает изображения моих поездов

Я выполняю предварительную обработку данных в TensorFlow, следуя инструкциям на их веб-сайте: https://www.tensorflow.org/tutorials/load_data/images

Однако после преобразования изображений в тензоры и присвоения каждому из них соответствующей метки я не могу построить их.

Параллельно загружаю пару (изображение, метку): labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE). Затем я проверяю форму изображения и соответствующую метку:

for image, label in labeled_ds.take(1):
  print("Image shape: ", image.numpy().shape)
  print("Label: ", label.numpy())

И получите Image shape: (80, 80, 3) Label: [False False True False], как ожидалось.

Затем я определяю следующую функцию для пакетной подготовки набора данных для обучения:

def prepare_for_training(ds, cache=True, shuffle_buffer_size=1000):
  ds = ds.batch(100)
  ds = ds.prefetch(buffer_size=AUTOTUNE)

  if cache:
    if isinstance(cache, str):
      ds = ds.cache(cache)
    else:
      ds = ds.cache()

  ds = ds.shuffle(buffer_size=shuffle_buffer_size)
  ds = ds.repeat()

  return ds



train_ds = prepare_for_training(labeled_ds)
image_batch, label_batch = next(iter(train_ds))

Но когда я хочу отобразить каждое изображение с его меткой, используя plt.show(), изображения не отображаются. Вот как я это делаю:

def show_batch(image_batch, label_batch):
    plt.figure(figsize=(10,10))
    for n in range(25):
        ax = plt.subplot(5,5,n+1)
        plt.imshow(image_batch[n])
        plt.title(CLASS_NAMES[label_batch[n]==1][0].title())
        plt.axis('off')
        return plt.show()


show_batch(image_batch.numpy(), label_batch.numpy())

Есть какие-нибудь подсказки о том, почему мои изображения могут не отображаться?


person Liz    schedule 08.02.2020    source источник


Ответы (2)



return plt.show находился внутри цикла for. Это должно быть снаружи вот так:

def show_batch(image_batch, label_batch):
    plt.figure(figsize=(10,10))
    for n in range(25):
        ax = plt.subplot(5,5,n+1)
        plt.imshow(image_batch[n])
        plt.title(CLASS_NAMES[label_batch[n]==1][0].title())
        plt.axis('off')

    return plt.show()
person Liz    schedule 08.02.2020