как сохранять, восстанавливать, делать прогнозы с сиамской сетью (с потерей триплетов)

Я пытаюсь разработать сиамскую сеть для простой проверки лица (и распознавания на втором этапе). У меня есть сеть, которую мне удалось обучить, но я немного озадачен тем, как сохранить и восстановить модель + делать прогнозы с помощью обученной модели. Надеясь, что, возможно, опытный человек в этой области поможет добиться прогресса.

Вот как я создаю свою сиамскую сеть, для начала...

model = ResNet50(weights='imagenet')   # get the original ResNet50 model
model.layers.pop()   # Remove the last layer
for layer in model.layers:
    layer.trainable = False   # do not train any of original layers

x = model.get_layer('flatten_1').output
model_out = Dense(128, activation='relu',  name='model_out')(x)
model_out = Lambda(lambda  x: K.l2_normalize(x,axis=-1))(model_out)
new_model = Model(inputs=model.input, outputs=model_out)

# At this point, a new layer (with 128 units) added and normalization applied.

# Now create siamese network on top of this

anchor_in = Input(shape=(224, 224, 3))
positive_in = Input(shape=(224, 224, 3))
negative_in = Input(shape=(224, 224, 3))

anchor_out = new_model(anchor_in)
positive_out = new_model(positive_in)
negative_out = new_model(negative_in)

merged_vector = concatenate([anchor_out, positive_out, negative_out], axis=-1)

# Define the trainable model
siamese_model = Model(inputs=[anchor_in, positive_in, negative_in],
                      outputs=merged_vector)
siamese_model.compile(optimizer=Adam(lr=.0001), 
                      loss=triplet_loss, 
                      metrics=[dist_between_anchor_positive,
                               dist_between_anchor_negative])

А я тренирую сиамскую_модель. Когда я его тренирую, если я правильно интерпретирую результаты, это на самом деле не тренирует базовую модель, а просто тренирует новую сиамскую сеть (по сути, тренируется только последний слой).

Но эта модель имеет 3 входных потока. После обучения мне нужно сохранить эту модель таким образом, чтобы она принимала только 1 или 2 входа, чтобы я мог выполнять прогнозы, вычисляя расстояние между двумя заданными изображениями. Как мне сохранить эту модель и использовать ее повторно сейчас?

Заранее спасибо!

ДОПОЛНЕНИЕ:

Если вам интересно, вот краткое описание сиамской модели.

siamese_model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_2 (InputLayer)            (None, 224, 224, 3)  0                                            
__________________________________________________________________________________________________
input_3 (InputLayer)            (None, 224, 224, 3)  0                                            
__________________________________________________________________________________________________
input_4 (InputLayer)            (None, 224, 224, 3)  0                                            
__________________________________________________________________________________________________
model_1 (Model)                 (None, 128)          23849984    input_2[0][0]                    
                                                                 input_3[0][0]                    
                                                                 input_4[0][0]                    
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 384)          0           model_1[1][0]                    
                                                                 model_1[2][0]                    
                                                                 model_1[3][0]                    
==================================================================================================
Total params: 23,849,984
Trainable params: 262,272
Non-trainable params: 23,587,712
__________________________________________________________________________________________________



Ответы (1)


Вы можете использовать приведенный ниже код, чтобы сохранить свою модель siamese_model.save_weights(MODEL_WEIGHTS_FILE)

А затем, чтобы загрузить вашу модель, вам нужно использовать siamese_model.load_weights(MODEL_WEIGHTS_FILE)

Спасибо

person Gazal    schedule 05.07.2018
comment
Как мы можем предсказать невидимые данные, используя эти загруженные веса? @Газаль - person Lakwin Chandula; 16.07.2020
comment
@LakwinChandula - Вы получили ответ на свой вопрос? - person Jithin P James; 17.05.2021