У меня проблема при импорте данных из файлов tfrecords. Каждый образец в tfrecords состоит из вектора признаков длиной 100 и вектора горячих меток длиной 13. Я использую приведенный ниже код для импорта данных из tfrecords, ссылаясь на официальное руководство https://www..tensorflow.org/programmers_guide/datasets
def read_data(examples):
features = {"features": tf.FixedLenFeature([seq_len], tf.int64),
"label": tf.FixedLenFeature([category], tf.int64)}
parsed_features = tf.parse_single_example(examples, features)
return parsed_features['features'], parsed_features['label']
# get next batch of data and label
def next_batch(filename, batch_size):
data = tf.data.TFRecordDataset(filename)
data = data.map(read_data)
data = data.batch(batch_size)
iterator = data.make_one_shot_iterator()
next_data, next_label = iterator.get_next()
return next_data, next_label
with tf.Session() as sess:
filetrain = 'train.tfrecords'
next_data, next_label = next_batch(filetrain, num_example_train)
sess.run(tf.global_variables_initializer())
data = sess.run(next_data)
label = sess.run(next_label)
Проблема в том, что после пакетирования порядок меток становится неправильным. А если убрать код 'data=data.batch', то все ок.
Я думаю, что одна из возможных причин заключается в том, что функции и метки группируются независимо друг от друга. Поэтому я попытался разобрать пример после пакетной обработки, но получил ошибку «Сериализованный ввод должен быть скалярным». Пожалуйста, помогите мне, если вы знаете, как справиться с этой проблемой, большое спасибо!