Неправильный порядок меток после пакетной обработки при импорте данных из tfrecords

У меня проблема при импорте данных из файлов tfrecords. Каждый образец в tfrecords состоит из вектора признаков длиной 100 и вектора горячих меток длиной 13. Я использую приведенный ниже код для импорта данных из tfrecords, ссылаясь на официальное руководство https://www..tensorflow.org/programmers_guide/datasets

def read_data(examples):
    features = {"features": tf.FixedLenFeature([seq_len], tf.int64),
               "label": tf.FixedLenFeature([category], tf.int64)}
    parsed_features = tf.parse_single_example(examples, features)
    return parsed_features['features'], parsed_features['label']

# get next batch of data and label
def next_batch(filename, batch_size):
    data = tf.data.TFRecordDataset(filename)
    data = data.map(read_data)
    data = data.batch(batch_size)
    iterator = data.make_one_shot_iterator()
    next_data, next_label = iterator.get_next()
    return next_data, next_label

with tf.Session() as sess:
    filetrain = 'train.tfrecords'
    next_data, next_label = next_batch(filetrain, num_example_train)
    sess.run(tf.global_variables_initializer())

    data = sess.run(next_data)
    label = sess.run(next_label)

Проблема в том, что после пакетирования порядок меток становится неправильным. А если убрать код 'data=data.batch', то все ок.

Я думаю, что одна из возможных причин заключается в том, что функции и метки группируются независимо друг от друга. Поэтому я попытался разобрать пример после пакетной обработки, но получил ошибку «Сериализованный ввод должен быть скалярным». Пожалуйста, помогите мне, если вы знаете, как справиться с этой проблемой, большое спасибо!


person Yuxiao Xu    schedule 18.12.2017    source источник


Ответы (1)


Я уверен, что это дубликат, но я не могу найти другой вопрос, поэтому отвечу здесь.

Ваша проблема заключается в том, что вы дважды вызываете sess.run() для данных и меток. Каждый раз, когда вы вызываете sess.run, ваш график оценивается (т. е. извлекается новая партия и выполняется через график до тех пор, пока все значения тензоров в списке, который вы передаете поскольку известен первый аргумент run).

При этом ваши data и label относятся к двум разным партиям (отсюда и тот факт, что они выглядят неправильно).

Вам нужно получить их в одном вызове с:

data, label = sess.run([next_data, next_label])
person GPhilo    schedule 18.12.2017
comment
Вот именно проблема! Спасибо за понятное объяснение! - person Yuxiao Xu; 19.12.2017
comment
Пожалуйста. Пожалуйста, отметьте ответ как решенный, чтобы я мог ссылаться на него для будущих дубликатов :) - person GPhilo; 19.12.2017