получить следующий тензор из набора данных в tf.while_loop

Я хочу перебирать набор данных до тех пор, пока не будет выполнено определенное условие, но я не знаю, как это сделать. Ниже мой код.

import tensorflow as tf

c = tf.constant([1,2,6])
d = tf.data.Dataset.from_tensor_slices((c,))
t = d.make_one_shot_iterator().get_next()

def condition(t):
  return t < 5

def body(t):
  # I don't know what to do here to return the next t
  return [t]

t = tf.while_loop(condition, body, [t])

with tf.Session() as sess:
    print(sess.run([t]))

В ответ на ответ Алекса ниже приведен более реалистичный пример того, чего я хочу достичь.

import tensorflow as tf

# I want to "merge" the dataset da to dataset db by "backfilling" da.
# So session.run will return [[1,'a'], [1,'x']], then [[5, 'c'],[3, 'y']]
# note that one element from dataset da is skipped, which is what I want to achieve with the while loop.
ta = tf.constant([1,2,5]) 
va = tf.constant(['a','b','c'])
da = tf.data.Dataset.from_tensor_slices((ta, va))

tb = tf.constant([1,3,6])
vb = tf.constant(['x','y','z'])
db = tf.data.Dataset.from_tensor_slices((tb, vb))

ea = da.make_one_shot_iterator().get_next()
eb = db.make_one_shot_iterator().get_next()

def condition(ea, eb):
  return ea[0] < eb[0]

def body(ea, eb):
  # I don't know what to do here to get the next ea.
  return ea, eb

result = tf.while_loop(condition, body, (ea, eb))

with tf.Session() as sess:
  sess.run([result])

Я мог бы перенести логику цикла while на python, как предложил Алекс, но я предполагаю, что оставление ее в графе потока данных будет иметь лучшую производительность.


person bill    schedule 04.05.2018    source источник


Ответы (2)


Я полагаю, вы еще не понимаете, как работает Tensorflow. Tensorflow tf.while_loop создает цикл while внутри вычислительного графа, добавляя операторы управления для повторного применения частей графа несколько раз, пока не будет выполнено определенное условие. Я бы посоветовал начать читать здесь, чтобы узнать, что такое графики и сессии.

Вы, конечно, не хотите перебирать свой набор данных внутри вычислительного графа, итерация, которую вы хотите, должна происходить в Python, а не в графе Tensorflow.

Вот как вы это сделаете:

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

for i in range(100):
  value = sess.run(next_element)

Более подробно это объясняется здесь.

person Alexander Harnisch    schedule 04.05.2018
comment
спасибо Алекс! Я с тобой согласен. Я обновил свой вопрос, чтобы прояснить свое намерение. - person bill; 04.05.2018
comment
Я до сих пор не согласен с тем, что вы должны делать это именно так. Но в любом случае, если вы действительно хотите, это не сработает, потому что make_one_shot_iterator возвращает итератор Python, а не тензор. Вы не можете перебирать его внутри графа Tensorflow, это невозможно. Если вы действительно хотите попробовать что-то подобное, вам нужно другое решение, например, предложенное Vipin Pillai. - person Alexander Harnisch; 07.05.2018
comment
@AlexanderHarnisch Я еще не тестировал это, но кажется, что это возможно: stackoverflow.com/questions/50237486/ - person Liz; 20.01.2019

Вы можете использовать метод Dataset.filter() для фильтрации элементы набора данных в соответствии с заданным пользователем предикатом. Вы должны передать функцию фильтра, которая возвращает тензор tf.bool, который оценивается как true, если вы хотите сохранить запись, или как false в противном случае.

person Vipin Pillai    schedule 05.05.2018