Оценщик Tensorflow — warm_start_from и model_dir

При использовании tf.estimator с warm_start_from и model_dir и каталоге warm_start_from и каталоге model_dir есть действительные контрольные точки, какая контрольная точка будет фактически восстановлена?

Чтобы дать некоторый контекст, мой код оценки выглядит так

est = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=model_dir,
    warm_start_from=warm_start_dir)

for epoch in range(num_epochs):
    est.train(input_fn=train_input_fn)
    est.evaluate(input_fn=eval_input_fn)

(Функции ввода используют одноразовые итераторы.)

Итак, во время первой итерации, когда model_dir пусто, я хочу, чтобы была загружена контрольная точка горячего запуска, но в следующую эпоху я хотел бы, чтобы была загружена промежуточная настроенная контрольная точка из последней итерации в model_dir. Но, по крайней мере, из журналов, похоже, что warm_start_dir все еще загружается.

Вероятно, я мог бы переопределить свой оценщик для следующих итераций, но мне интересно, не следует ли его как-то встроить в оценщик.


person mtngld    schedule 15.04.2018    source источник
comment
Estimator загрузит операции для горячего запуска весов, но перезапишет их значением, сохраненным в последней контрольной точке. Estimator должен добавить функцию проверки отсутствия сохраненных контрольных точек перед загрузкой весов warm_start.   -  person kww    schedule 21.06.2018


Ответы (1)


У меня была аналогичная проблема, я решил ее, предоставив хук инициализации, который запускается при запуске сеанса, и используя tf.estimator.train_and_evaluate (хотя я не могу присвоить себе все это решение, так как я видел что-то подобное для другого цель в другом месте):

class InitHook(tf.train.SessionRunHook):
    """initializes model from a checkpoint_path
    args:
        modelPath: full path to checkpoint
    """
    def __init__(self, checkpoint_dir):
        self.modelPath = checkpoint_dir
        self.initialized = False

    def begin(self):
        """
        Restore encoder parameters if a pre-trained encoder model is available and we haven't trained previously
        """
        if not self.initialized:
            log = logging.getLogger('tensorflow')
            checkpoint = tf.train.latest_checkpoint(self.modelPath)
            if checkpoint is None:
                log.info('No pre-trained model is available, training from scratch.')
            else:
                log.info('Pre-trained model {0} found in {1} - warmstarting.'.format(checkpoint, self.modelPath))
                tf.train.warm_start(checkpoint)
            self.initialized = True

Затем для обучения:

initHook = InitHook(checkpoint_dir = warm_start_dir)
trainSpec = tf.estimator.TrainSpec(
    input_fn = train_input_fn,
    max_steps = N_STEPS, 
    hooks = [initHook]
)
evalSpec = tf.estimator.EvalSpec(
    input_fn = eval_input_fn,
    steps = None,
    name = 'eval',
    throttle_secs = 3600
)
tf.estimator.train_and_evaluate(estimator, trainSpec, evalSpec)

Это запускается один раз в начале для инициализации переменных из warm_start_dir. Позже, когда в оценщике model_dir появятся новые контрольные точки, он продолжит работу с них.

person kamyonet    schedule 17.10.2018
comment
Отлично, большое спасибо, решил мою проблему. - person super.single430; 21.03.2019