Ошибка - ввод, который не является символьным тензором - логический

Я пытаюсь отрегулировать веса слоя Dense из-за потери бинарной кроссэнтропии. A создали общий слой, который выводит для двух векторов два значения (encoded_value_1 и encoded_value_2). Я хочу создать логическое значение, равное 1, если значение encoded_value_1 превосходит значение encoded_value_2. Для этого я использую greater через слой Lambda. Тем не менее, это вызывает ошибку (см. Ниже).

import keras
from keras.backend import greater
from keras.layers import Input, LSTM, Dense, Lambda, concatenate
from keras.models import Model

value_1 = Input(shape=(4,))
value_2 = Input(shape=(4,))

shared_layer = Dense(1)
encoded_value_1 = shared_layer(value_1)
encoded_value_2 = shared_layer(value_2)

x = Lambda(greater,output_shape=(1,))((encoded_value_1,encoded_value_2)) 
model = Model(inputs=[value_1, value_2], outputs=x)
model.compile(optimizer='adam',loss='binary_crossentropy', metrics='accuracy'])

NB: Я также пытался объединить два слоя, у меня была такая же ошибка.

merged_vector = concatenate([encoded_value_1, encoded_value_2], axis=-1)
x = Lambda(greater,output_shape=(1,))((merged_vector[0],merged_vector[1]))

ValueError: слой lambda_4 был вызван с входом, который не является символьным тензором. Полученный тип:. Полный ввод: [(,)]. Все входы в слой должны быть тензорами.


person Théo Simier    schedule 27.05.2019    source источник


Ответы (1)


Есть три пункта:

  1. Когда слой Lambda имеет более одного входа, они должны передаваться как список тензоров, а не кортеж.

  2. Результат greater - это логический тензор, который вам нужно преобразовать в float для выполнения вычислений.

  3. greater принимает два входа, поэтому вам нужно обернуть его внутри функции lambda python, чтобы иметь возможность использовать его в Lambda слое в Keras.

Следовательно, у нас было бы:

from keras import backend as K

# ...
x = Lambda(lambda z: K.cast(K.greater(z[0], z[1]), K.floatx()),output_shape=(1,))([encoded_value_1,encoded_value_2])

А также не забудьте об отсутствующей открывающей скобке для аргумента metrics:

..., metrics=['accuracy'])
             ^
             |
             |
          missing!
person today    schedule 27.05.2019