Как умножить tf.keras.layers
на tf.Variable
?
Контекст: я создаю сверточный фильтр, зависящий от выборки, который состоит из общего фильтра W
, который преобразуется посредством сдвига и масштабирования, зависящего от выборки. Таким образом, сверточный исходный фильтр W
преобразуется в aW + b
, где a
— масштабирование, зависящее от выборки, а b
— сдвиг, зависящий от выборки. Одним из применений этого является обучение автоэнкодера, где образцом зависимости является метка, поэтому каждая метка сдвигает/масштабирует сверточный фильтр. Из-за сверток, зависящих от выборки/метки, я использую tf.nn.conv2d
, который принимает фактические фильтры в качестве входных данных (а не только количество/размер фильтров), и лямбда-слой с tf.map_fn
для применения другого преобразованного фильтра (на основе метки) для каждый образец. Хотя детали различаются, такой подход к свертке, зависящий от выборки, обсуждается в этом посте: -sample-in-the-mini-batch">Tensorflow: свертки с разными фильтрами для каждого образца в мини-пакете.
Вот что я думаю:
input_img = keras.Input(shape=(28, 28, 1))
label = keras.Input(shape=(10,)) # number of classes
num_filters = 32
shift = layers.Dense(num_filters, activation=None, name='shift')(label) # (32,)
scale = layers.Dense(num_filters, activation=None, name='scale')(label) # (32,)
# filter is of shape (filter_h, filter_w, input channels, output filters)
filter = tf.Variable(tf.ones((3,3,input_img.shape[-1],num_filters)))
# TODO: need to shift and scale -> shift*(filter) + scale along each output filter dimension (32 filter dimensions)
Я не уверен, как реализовать часть TODO
. Я думал о tf.keras.layers.Multiply()
для масштабирования и tf.keras.layers.Add()
для сдвига, но, насколько мне известно, они не работают с tf.Variable. Как мне обойти это? Предполагая, что вещание размеров/формы работает, я хотел бы сделать что-то вроде этого (примечание: вывод должен по-прежнему иметь ту же форму, что и var, и просто масштабируется по каждому из 32 размеров выходного фильтра)
output = tf.keras.layers.Multiply()([var, scale])
var
? - person thushv89   schedule 02.10.2020var
наfilter
, чтобы указать, что это сверточный фильтр (до смещения/масштабирования). Он имеет форму (filter_height, filter_width, входные каналы, выходные фильтры), которая в этом примере равна (3,3,1,32). - person Jane Sully   schedule 02.10.2020