Модели TensorFlow широко используются в приложениях машинного обучения, и их развертывание на устройствах с ограниченными ресурсами часто требует их преобразования в формат TensorFlow Lite (.tflite). В этом руководстве рассматриваются различные подходы к преобразованию модели TensorFlow .pb в .tflite с учетом различных версий TensorFlow и форматов моделей. Используете ли вы TensorFlow 1.x или 2.x и находится ли ваша модель в SavedModel или GraphDef Формат, я тебя прикрою.

Подходы к конверсии

Я расскажу вам о четырех различных подходах, каждый из которых нацелен на определенную версию TensorFlow и формат модели.

Версия 1: Использование TensorFlow 1.x с форматом SavedModel

В этом подходе я использую инфраструктуру TensorFlow 1.x для загрузки модели, хранящейся в формате SavedModel. Фрагмент кода демонстрирует, как преобразовать его в .tflite с помощью конвертера TensorFlow Lite.

import tensorflow as tf

model = tf.compat.v1.saved_model.load("path/to/saved_model")
converter = tf.compat.v1.lite.TFLiteConverter.from_saved_model(model)
tflite_model = converter.convert()

# Save tflite_model to disk
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)

Версия 2: Использование TensorFlow 1.x с GraphDef (.pb)

Формат Этот метод ориентирован на TensorFlow 1.x и модели, сохраненные в формате GraphDef (.pb). Анализируя определение графа и указывая входные и выходные тензоры, мы используем преобразователь TensorFlow Lite для создания модели .tflite.

import tensorflow as tf

with tf.compat.v1.gfile.FastGFile("path/to/model.pb", "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())

converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(graph_def, input_arrays=["input_tensor"], output_arrays=["output_tensor"])
tflite_model = converter.convert()

# Save tflite_model to disk
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)

Версия 3: Использование TensorFlow 2.x с форматом SavedModel

Для тех, кто использует TensorFlow 2.x, этот подход показывает, как загрузить модель из формата SavedModel и преобразовать ее в TensorFlow Lite с помощью tf.lite.TFLiteConverter.

import tensorflow as tf

model = tf.saved_model.load("path/to/saved_model")
converter = tf.lite.TFLiteConverter.from_saved_model("path/to/saved_model")
tflite_model = converter.convert()

# Save tflite_model to disk
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)

Версия 4: Использование TensorFlow 2.x с конкретными функциями

Этот подход предназначен для пользователей TensorFlow 2.x, которые определили свою модель как tf.function. Мы расскажем, как преобразовать эту ConcreteFunction в формат .tflite, упрощая развертывание.

import tensorflow as tf

@tf.function
def model_inference(input):
    # Load and run the model here
    return output

concrete_func = model_inference.get_concrete_function(tf.TensorSpec(shape=(None, ...), dtype=tf.float32))

converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
tflite_model = converter.convert()

# Save tflite_model to disk
with open("converted_model.tflite", "wb") as f:
    f.write(tflite_model)

Заключение

Преобразование модели TensorFlow .pb в .tflite — ключевой шаг в развертывании моделей машинного обучения на периферийных устройствах. Используете ли вы TensorFlow 1.x или 2.x и находится ли ваша модель в SavedModel или GraphDef предоставленные подходы позволяют легко преобразовать ваши модели для эффективного развертывания. Следуя примерам кода, вы сможете легко создавать модели .tflite, которые позволяют использовать мощные приложения машинного обучения на устройствах с ограниченными ресурсами.