Определяне дали дадена стойност е в набор в TensorFlow

Функциите tf.logical_or, tf.logical_and и tf.select са много полезни.

Да предположим обаче, че имате стойност x и искате да видите дали е в set(a, b, c, d, e). В python просто бихте написали:

if x in set([a, b, c, d, e]):
  # Do some action.

Доколкото мога да преценя, единственият начин да направите това в TensorFlow е да вложите 'tf.logical_or' заедно с 'tf.equal'. Предоставих само една итерация на тази концепция по-долу:

tf.logical_or(
    tf.logical_or(tf.equal(x, a), tf.equal(x, b)),
    tf.logical_or(tf.equal(x, c), tf.equal(x, d))
)

Чувствам, че трябва да има по-лесен начин да направите това в TensorFlow. Е там?


person LeavesBreathe    schedule 05.01.2016    source източник


Отговори (3)


Разгледайте този свързан въпрос: Преброяване на броя стойности True в булев Тензор

Трябва да можете да изградите тензор, състоящ се от [a, b, c, d, e] и след това да проверите дали някой от редовете е равен на x, като използвате tf.equal(.)

person Rafał Józefowicz    schedule 05.01.2016
comment
Благодаря за прозрението. Reduce_sum е най-добрият начин. - person LeavesBreathe; 05.01.2016
comment
Можете също да използвате tf.listdiff, за да постигнете същото. - person dga; 05.01.2016
comment
@dga, който показва само разликата, а не подобните? - person user3352632; 23.01.2021
comment
За някой нов в Tensorflow е трудно да види как вашата свързана публикация се отнася към ситуацията на OP. Може би бихте могли да публикувате пълен кодов фрагмент тук? - person Addison Klinke; 12.02.2021

За да предоставите по-конкретен отговор, кажете, че искате да проверите дали последното измерение на тензора x съдържа някаква стойност от 1D тензор s, можете да направите следното:

tile_multiples = tf.concat([tf.ones(tf.shape(tf.shape(x)), dtype=tf.int32), tf.shape(s)], axis=0)
x_tile = tf.tile(tf.expand_dims(x, -1), tile_multiples)
x_in_s = tf.reduce_any(tf.equal(x_tile, s), -1))

Например за s и x:

s = tf.constant([3, 4])
x = tf.constant([[[1, 2, 3, 0, 0], 
                  [4, 4, 4, 0, 0]], 
                 [[3, 5, 5, 6, 4], 
                  [4, 7, 3, 8, 9]]])

x има форма [2, 2, 5], а s има форма [2], така че tile_multiples = [1, 1, 1, 2], което означава, че ще подредим последното измерение на x 2 пъти (по веднъж за всеки елемент в s) по ново измерение. И така, x_tile ще изглежда така:

[[[[1 1]
   [2 2]
   [3 3]
   [0 0]
   [0 0]]

  [[4 4]
   [4 4]
   [4 4]
   [0 0]
   [0 0]]]

 [[[3 3]
   [5 5]
   [5 5]
   [6 6]
   [4 4]]

  [[4 4]
   [7 7]
   [3 3]
   [8 8]
   [9 9]]]]

и x_in_s ще сравни всяка от подредените стойности с една от стойностите в s. tf.reduce_any по дължината на последния дим ще върне истина, ако някоя от стойностите на плочките е била в s, давайки крайния резултат:

[[[False False  True False False]
  [ True  True  True False False]]

 [[ True False False False  True]
  [ True False  True False False]]]
person Emma Strubell    schedule 05.01.2018

Ето две решения, искаме да проверим дали query е в whitelist

whitelist = tf.constant(["CUISINE", "DISH", "RESTAURANT", "ADDRESS"])
query = "RESTAURANT"

#use broadcasting for element-wise tensor operation
broadcast_equal = tf.equal(whitelist, query)

#method 1: using tensor ops
broadcast_equal_int = tf.cast(broadcast_equal, tf.int8)
broadcast_sum = tf.reduce_sum(broadcast_equal_int)

#method 2: using some tf.core API
nz_cnt = tf.count_nonzero(broadcast_equal)

sess.run([broadcast_equal, broadcast_sum, nz_cnt])
#=> [array([False, False,  True, False]), 1, 1]

Така че, ако изходът е > 0, тогава елементът е в комплекта.

person eggie5    schedule 20.10.2019
comment
Как работи това, когато и query, и whitelist имат повече от един елемент? - person Addison Klinke; 12.02.2021