Вкратце, я создал конвейер ввода данных с помощью tenorflow Dataset API. Затем я реализовал модель CNN для классификации с использованием keras, которую преобразовал в оценщик. Я накормил свой оценщик Train and Eval Specs своим input_fn, предоставив входные данные для обучения и оценки. И в качестве последнего шага я начал обучение модели с tf.estimator.train_and_evaluate
def my_input_fn(tfrecords_path):
dataset = (...)
return batch_fbanks, batch_labels
def build_model():
model = tf.keras.models.Sequential()
model.add(...)
model.compile(...)
return model
model = build_model()
run_config=tf.estimator.RunConfig(model_dir,save_summary_steps=100,save_checkpoints_steps=1000)
estimator = tf.keras.estimator.model_to_estimator(model,config=run_config)
def serving_input_receiver_fn():
inputs = {'Conv1_input': tf.compat.v1.placeholder(shape=[None, 11,120,1], dtype=tf.float32)}
return tf.estimator.export.ServingInputReceiver(inputs, inputs)
exporter = tf.estimator.BestExporter(serving_input_receiver_fn, name="best_exporter", exports_to_keep=5)
train_spec_dnn = tf.estimator.TrainSpec(input_fn = lambda: my_input_fn(train_data_path),hooks=[hook])
eval_spec_dnn = tf.estimator.EvalSpec(input_fn = lambda: my_eval_input_fn(eval_data_path),exporters=exporter,start_delay_secs=0,throttle_secs=15)
tf.estimator.train_and_evaluate(estimator, train_spec_dnn, eval_spec_dnn)
Я сохраняю 5 лучших контрольных точек, используя tf.estimator.BestExporter
, как показано выше. После завершения обучения я хочу перезагрузить лучшую модель и преобразовать ее в оценщик, чтобы повторно оценить модель и спрогнозировать новый набор данных. Однако моя проблема заключается в восстановлении контрольной точки для оценщика. Я пробовал несколько решений, но каждый раз, когда я не получаю объект оценки, мне нужно запускать его методы evaluate
и predict
.
Чтобы указать больше, каждая из лучших директорий контрольных точек организована следующим образом:
./
variables/
variables.data-00000-of-00002
variables.data-00001-of-00002
variables.index
saved_model.pb
Итак, вопрос в том, как я могу получить объект оценки с лучшей контрольной точки, чтобы я мог использовать его для оценки моей модели и прогнозирования новых данных?
Примечание. Я нашел несколько предлагаемых решений, основанных на функциях TensorFlow v1, которые не могут решить мою проблему, потому что я работаю с TF v2.
Большое спасибо, любая помощь приветствуется.