Это вторая часть серии моделей подобия Tensorflow. Первая часть доступна по следующей ссылке:



Вторая часть посвящена практическому эксперименту с использованием поиска подобия Tensorflow, в котором поиск выполняется в наборе данных Fashion MNIST. Проект будет обучаться на наборе поездов набора данных Fashion MNIST, а прогнозирование будет выполняться на тестовом наборе. Также будет реализовано небольшое веб-приложение для упрощения оценки модели и взаимодействия с ней.

Проект состоит из 4 основных этапов:

  • Подготовка набора данных
  • Обучение модели
  • Оценка модели (офлайн)
  • Создание интерактивного веб-приложения для запуска системы рекомендаций по продуктам

Шаг 1: Подготовка набора данных

Подготовка набора данных решает следующую задачу:

  • Загрузить данные в локальный каталог
  • Преобразование обучающих и тестовых данных в правильную форму, необходимую для этапа обучения модели.
  • Выберите подмножество тестовых изображений (в тензорном формате), преобразуйте и сохраните изображения в локальный каталог (требуется для веб-приложения).

Шаг 2: Обучение модели

Tensorflow Similarity предоставляет различные способы загрузки и выборки данных для обучения модели. Однако в этом примере используется MultiShotMemorySampler из tensorflow_similarity.sampler, чтобы продемонстрировать возможность загрузки любых пользовательских наборов данных.

x_train = np.load(os.path.join(data_dir, "train_images.npy"))    y_train = np.load(os.path.join(data_dir, "train_labels.npy"))    num_classes = len(np.unique(y_train))
# data sampler that generates balanced batches from fashion-mnist dataset
sampler = MultiShotMemorySampler(
    x_train,
    y_train,
    classes_per_batch=num_classes,  # make sure all classes are available in each batch
)

Далее строится простая архитектура модели для модели сходства. Однако для более сложного набора данных рекомендуется более сложная архитектура модели. Основная идея состоит в том, чтобы добавить MetricEmbedding в качестве выходных данных модели.

# build model architecture
inputs = layers.Input(shape=(28, 28, 1))
x = layers.Rescaling(1 / 255)(inputs)
x = layers.Conv2D(64, 3, activation="relu")(x)
x = layers.Flatten()(x)
x = layers.Dense(64, activation="relu")(x)
outputs = MetricEmbedding(64)(x)
model = SimilarityModel(inputs, outputs)

Поскольку это проблема обучения метрике, требуется другая функция потерь. Следовательно, в этом случае используется MultiSimilarityLoss(distance="cosine").

После обучения модели требуется этап построения индекса, чтобы примеры поездов можно было найти.

model.index(x=x_train, y=y_train, data=x_train)

Шаг 3: Оценка модели (необязательно)

Этот шаг обеспечивает возможность ручной оценки. Он предоставляет возможность случайного выбора подмножества тестовых данных и поиска ближайших примеров поездов.

Шаг 4. Создание интерактивного веб-приложения

Интерактивное веб-приложение разработано с использованием Dash и Plotly. Приложение отображает список изображений, используемых при обучении модели, и позволяет пользователям находить ближайшие изображения из выбранного изображения.

Это работает так: после того, как изображение выбрано пользователем, тензор выбранного изображения будет отправлен обученной модели, и модель вернет индексы элементов с ближайшим расстоянием до тестового тензора.

В качестве примера был выбран пуловер из тестового набора и все возвращенные ближайшие элементы также выглядят как пуловеры.

Мой ключевой вывод

Это просто забавный проект, который был разработан для демонстрации возможностей библиотеки Tensorflow Similarity. Однако систему рекомендаций можно обобщить, чтобы она могла поддерживать различные продукты и варианты использования. Кроме того, вместо обучения автономного набора данных система может быть спроектирована таким образом, чтобы существовал непрерывный поток обучения модели, который позволяет добавлять новые данные в модель.

Традиционно рекомендательную систему можно создать, выполнив поиск K-ближайших соседей (KNN) (либо по необработанным функциям, либо по внедренным функциям). Однако эти подходы требуют тщательной разработки признаков, поскольку KNN — это отдельный шаг, если встраивание признаков используется в качестве входных данных для задачи поиска. Кроме того, существуют проблемы с использованием KNN для многомерных объектов, таких как изображения (проклятие размерности https://en.wikipedia.org/wiki/Curse_of_Dimensionity). Использование подобия Tensorflow явно поможет сократить инженерные усилия, поскольку обучение модели глубокого обучения и поиск ближайших соседей включаются в один шаг. Более того, обратная связь от поиска ближайших соседей уже будет учитываться в потерях при обучении, и, следовательно, ожидается, что это поможет улучшить производительность системы. Кроме того, более простой конвейер поможет сократить значительные затраты на проектирование.

Код проекта доступен по адресу https://github.com/att288/tf-similarity-fashion-mnist. Есть раздел развертывания докера, который помогает облегчить запуск проекта. После выполнения команды docker веб-приложение должно быть доступно по адресу http://localhost:8050.