Как использовать Keras Multiply() с tf.Variable?

Как умножить tf.keras.layers на tf.Variable?

Контекст: я создаю сверточный фильтр, зависящий от выборки, который состоит из общего фильтра W, который преобразуется посредством сдвига и масштабирования, зависящего от выборки. Таким образом, сверточный исходный фильтр W преобразуется в aW + b, где a — масштабирование, зависящее от выборки, а b — сдвиг, зависящий от выборки. Одним из применений этого является обучение автоэнкодера, где образцом зависимости является метка, поэтому каждая метка сдвигает/масштабирует сверточный фильтр. Из-за сверток, зависящих от выборки/метки, я использую tf.nn.conv2d, который принимает фактические фильтры в качестве входных данных (а не только количество/размер фильтров), и лямбда-слой с tf.map_fn для применения другого преобразованного фильтра (на основе метки) для каждый образец. Хотя детали различаются, такой подход к свертке, зависящий от выборки, обсуждается в этом посте: -sample-in-the-mini-batch">Tensorflow: свертки с разными фильтрами для каждого образца в мини-пакете.

Вот что я думаю:

input_img = keras.Input(shape=(28, 28, 1))  
label = keras.Input(shape=(10,)) # number of classes

num_filters = 32
shift = layers.Dense(num_filters, activation=None, name='shift')(label) # (32,)
scale = layers.Dense(num_filters, activation=None, name='scale')(label) # (32,)

# filter is of shape (filter_h, filter_w, input channels, output filters)
filter = tf.Variable(tf.ones((3,3,input_img.shape[-1],num_filters)))
# TODO: need to shift and scale -> shift*(filter) + scale along each output filter dimension (32 filter dimensions)

Я не уверен, как реализовать часть TODO. Я думал о tf.keras.layers.Multiply() для масштабирования и tf.keras.layers.Add() для сдвига, но, насколько мне известно, они не работают с tf.Variable. Как мне обойти это? Предполагая, что вещание размеров/формы работает, я хотел бы сделать что-то вроде этого (примечание: вывод должен по-прежнему иметь ту же форму, что и var, и просто масштабируется по каждому из 32 размеров выходного фильтра)

output = tf.keras.layers.Multiply()([var, scale]) 

person Jane Sully    schedule 02.10.2020    source источник
comment
Можете ли вы объяснить форму вашего var?   -  person thushv89    schedule 02.10.2020
comment
Извините за отсутствие ясности. Я изменил var на filter, чтобы указать, что это сверточный фильтр (до смещения/масштабирования). Он имеет форму (filter_height, filter_width, входные каналы, выходные фильтры), которая в этом примере равна (3,3,1,32).   -  person Jane Sully    schedule 02.10.2020


Ответы (1)


Это требует некоторой работы и нуждается в пользовательском слое. Например, вы не можете использовать tf.Variable с tf.keras.Lambda

class ConvNorm(layers.Layer):
    def __init__(self, height, width, n_filters):
        super(ConvNorm, self).__init__()
        self.height = height  
        self.width = width
        self.n_filters = n_filters

    def build(self, input_shape):              
        self.filter = self.add_weight(shape=(self.height, self.width, input_shape[-1], self.n_filters),
                                 initializer='glorot_uniform',
                                 trainable=True)        
        # TODO: Add bias too


    def call(self, x, scale, shift):
        shift_reshaped = tf.expand_dims(tf.expand_dims(shift,1),1)
        scale_reshaped = tf.expand_dims(tf.expand_dims(scale,1),1)

        norm_conv_out = tf.nn.conv2d(x, self.filter*scale + shift, strides=(1,1,1,1), padding='SAME')
                
        return norm_conv_out

Использование слоя

import tensorflow as tf
import tensorflow.keras.layers as layers

input_img = layers.Input(shape=(28, 28, 1))  
label = layers.Input(shape=(10,)) # number of classes

num_filters = 32
shift = layers.Dense(num_filters, activation=None, name='shift')(label) # (32,)
scale = layers.Dense(num_filters, activation=None, name='scale')(label) # (32,)

conv_norm_out = ConvNorm(3,3,32)(input_img, scale, shift)
print(norm_conv_out.shape)

Примечание. Обратите внимание, что я не добавил смещения. Вам также понадобится смещение для слоя свертки. Но это прямолинейно.

person thushv89    schedule 02.10.2020
comment
Большое спасибо за этот ответ. Я определенно не смог бы реализовать пользовательский слой самостоятельно, поэтому я благодарен за помощь. У меня есть несколько дополнительных вопросов, которые могут быть очевидными. 1) scale_reshaped и shift_reshaped случайно дважды включены в call()? 2) Я не думаю, что scale_reshaped и shift_reshaped используются для модификации фильтра, но я думаю, что они должны? Это задумано/правильно ли я понимаю? 3) Почему ты делаешь *x[2]+x[3] для tf.nn.conv2d? Спасибо! - person Jane Sully; 02.10.2020
comment
@JaneSully все, что вы сказали, верно. Наверное, я был слишком неосторожен. Исправлены проблемы - person thushv89; 02.10.2020