Пакетная нормализация в пользовательском оценщике в Tensorflow

Я имею в виду примечание по адресу tf.layers.batch_normilization:

Примечание: при обучении необходимо обновить moving_mean и moving_variance. По умолчанию операции обновления помещаются в tf.GraphKeys.UPDATE_OPS, поэтому их необходимо добавить как зависимость к train_op. Например:

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op = optimizer.minimize(loss)

Как можно реализовать это в пользовательском оценщике? Например, посмотрите на этот пример на веб-сайте Tensorflow: Полная модель морского ушка


person Jan Krynauw    schedule 25.07.2017    source источник


Ответы (2)


Я думаю, вы можете передать train_op, на который вы ссылаетесь, в параметре train_op EstimatorSpec.

person user1566912    schedule 21.03.2018

По следующей проблеме в самом низу у вас есть пример https://github.com/tensorflow/tensorflow/issues/16455

if mode == tf.estimator.ModeKeys.TRAIN:
    lr = 0.001
    optimizer = tf.train.RMSPropOptimizer(learning_rate=lr, decay=0.9)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op)
person Marc Moreaux    schedule 19.06.2018