Создание простейшей сети GAN в PyTorch

Я долго делал GAN в TensorFlow / Keras. Честно говоря, слишком долго, потому что изменить сложно. Это потребовало некоторой убедительности, но в конце концов я укусил пулю и переключился на PyTorch. К сожалению, большинство руководств по PyTorch GAN, с которыми мне приходилось сталкиваться, были чрезмерно сложными, сосредоточены больше на теории GAN, чем на применении, или, как ни странно, не имели смысла. Чтобы исправить это, я написал этот микро-учебник по созданию ванильного GAN в PyTorch с упором на PyTorch. Сам код доступен здесь (обратите внимание, что код github и суть в этом руководстве немного отличаются). Я рекомендую открыть это руководство в двух окнах, в одном из которых отображается код, а в другом - пояснения.

Требования

  1. Python 3.7 или выше. Если немного ниже, вам придется реорганизовать f-строки.
  2. PyTorch 1.5 Не знаете, как его установить? Это может помочь.
  3. Семнадцать или восемнадцать минут вашего времени. Всего лишь двенадцать, если ты умен.

Задача под рукой

Создайте функцию G: Z → X, где Z ~ U (0, 1) и X ~ N (0, 1).

На английском языке это означает создать GAN, который аппроксимирует« нормальное распределение с учетом равномерного случайного шума на входе». Это означает, что входными данными для GAN будет одно число, как и выходными данными. Обратите внимание, что для простоты мы будем использовать здесь функцию генерации данных вместо обучающего набора данных.

Давай просто перейдем к делу

Убедитесь, что у вас установлена ​​правильная версия Python, и установите PyTorch. Затем создайте новый файл vanilla_GAN.py и добавьте следующий импорт:

import torch
from torch import nn
import torch.optim as optim

Наш сценарий GAN будет состоять из трех компонентов: сети генератора, сети дискриминатора и самой сети GAN, в которой размещаются и обучаются две сети. Начнем с генератора:

Генератор

Добавьте в свой скрипт следующее:

Наш класс Generator наследуется от класса PyTorch nn.Module, который является базовым классом для модулей нейронной сети. Короче говоря, он сообщает PyTorch: «Это нейронная сеть». У нашего класса генератора есть два метода:

Генератор .__ init__

Инициализируйте объект. Во-первых, это вызывает метод nn.Module __init__ с использованием super. Затем он создает подмодули (то есть слои) и назначает их как переменные экземпляра. Это включает:

  • Линейный (т. Е. Полностью связанный, т. Е. Плотный) слой с входной шириной latent_dim и выходной шириной 64.
  • Линейный слой с входной шириной 64 и выходной шириной 32.
  • Линейный слой с входной шириной 32 и выходной шириной 1.
  • Активация LeakyReLU.
  • Указанная активация выхода.

Поскольку эти модули сохраняются как переменные экземпляра в классе, который наследуется от nn.Module, PyTorch может отслеживать их, когда приходит время обучать сеть; подробнее об этом позже.

Генератор. Вперед

Метод forward необходим для любого класса, унаследованного от nn.Module, поскольку он определяет структуру сети. PyTorch использует структуру define-by-run, что означает, что вычислительный граф нейронной сети строится автоматически по мере того, как вы объединяете простые вычисления вместе. Все это очень питонично. В нашем forward методе мы перебираем модули генератора и применяем их к выходным данным предыдущего модуля, возвращая окончательный результат. Когда вы запускаете сеть (например: prediction = network(data), метод forward - это то, что вызывается для вычисления выходных данных.

Генератор. Назад

Неа! PyTorch использует Autograd для автоматического распознавания; когда вы запускаете метод forward, PyTorch автоматически отслеживает вычислительный граф, и, следовательно, вам не нужно указывать ему, как распространять градиенты в обратном направлении. Как это выглядит на практике? Продолжай читать.

Дискриминатор

Добавьте в свой скрипт следующее:

Наш объект Discriminator будет почти идентичен нашему генератору, но, глядя на класс, вы можете заметить два отличия. Во-первых, сеть была параметризована и немного переработана, чтобы сделать ее более гибкой. Во-вторых, функция вывода была зафиксирована на Sigmoid, потому что дискриминатору будет поручено классифицировать образцы как реальные (1) или сгенерированные (0).

Дискриминатор .__ init__

Дискриминатор __init__ метод выполняет три функции. Опять же, он вызывает метод nn.Module __init__, используя super. Затем он сохраняет входной размер как объектную переменную. Наконец, он вызывает метод _init_layers. В качестве аргументов __init__ принимает входное измерение и список целых чисел с именем layers, который описывает ширину nn.Linear модулей, включая выходной слой.

Дискриминатор._init_layers

Этот метод создает экземпляры сетевых модулей. Тело этого метода можно было бы поместить в __init__, но я считаю более понятным, если бы шаблон инициализации объекта был отделен от кода построения модуля, особенно по мере роста сложности сети. Этот метод выполняет итерацию по аргументу layers и создает список модулей nn.Linear подходящего размера, а также активацию Leaky ReLU после каждого внутреннего слоя и активацию сигмоида после последнего слоя. Эти модули хранятся в объекте ModuleList, который функционирует как обычный список Python, за исключением того факта, что PyTorch распознает его как список модулей, когда приходит время обучать сеть. Существует также ModuleDict класс, который служит той же цели, но работает как словарь Python; подробнее об этом позже.

Дискриминатор. Вперед

Метод forward функционирует так же, как его коллега в генераторе. Однако, поскольку мы сохранили наши модули в виде списка, мы можем просто перебирать этот список, применяя каждый модуль по очереди.

VanillaGAN

Добавьте в свой скрипт следующее:

Наш класс VanillaGAN содержит объекты Generator и Discriminator и обрабатывает их обучение.

VanillaGAN .__ init__

В качестве входных данных конструктор VanillaGAN принимает:

  • Объект Generator.
  • Объект-дискриминатор.
  • Функция шума. Это функция, используемая для выборки скрытых векторов Z, которые наш генератор будет отображать на сгенерированные выборки X. Эта функция должна принимать целое число num в качестве входных данных и возвращать тензор 2D Torch с формой (num, latent_dim).
  • Функция данных. Это функция, которую наш Генератор должен изучить. Эта функция должна принимать целое число num в качестве входных данных и возвращать тензор 2D Torch с формой (num, data_dim), где data_dim - размер данных, которые мы пытаемся сгенерировать, input_dim нашего Дискриминатора.
  • По желанию, размер обучающей мини-партии.
  • По желанию устройство. Это может быть cpu или cuda, если вы хотите использовать графический процессор.
  • По желанию, скорость обучения для генератора и дискриминатора.

Где необходимо, эти аргументы сохраняются как переменные экземпляра.

Целью GAN является потеря двоичной кросс-энтропии (nn.BCELoss), которую мы создаем и назначаем как объектную переменную criterion.

В нашей сети GAN используются два оптимизатора: один для генератора, а другой - для дискриминатора. Давайте разберем оптимизатор Генератора, экземпляр Adam. Оптимизаторы управляют обновлением параметров нейронной сети с учетом градиентов. Для этого оптимизатору необходимо знать, какие параметры его должны учитывать; в данном случае это discriminator.parameters(). Пару минут назад я сказал тебе

PyTorch может отслеживать [модули], когда приходит время обучать сеть.

Поскольку объект Discriminator наследуется от nn.Module, он наследует метод parameters, который возвращает все обучаемые параметры во всех модулях, заданных как переменные экземпляра для Discriminator (поэтому нам пришлось использовать nn.ModuleList вместо списка Python, чтобы PyTorch знал, что проверьте каждый элемент на наличие параметров). Оптимизатору также дается указанная скорость обучения и бета-параметры, которые хорошо работают для GAN. Оптимизатор генератора работает таким же образом, за исключением того, что вместо этого он отслеживает параметры генератора и использует немного меньшую скорость обучения.

Наконец, мы сохраняем вектор-столбец из единиц и вектор-столбец из нулей в качестве меток классов для обучения, чтобы нам не приходилось повторно создавать их повторно.

VanillaGAN.generate_samples

Это вспомогательная функция для получения случайных выборок из генератора. Вызывается без аргументов, генерирует batch_size выборки. Это можно изменить, указав аргумент num для создания num выборок или предоставив ему двумерный тензор PyTorch, содержащий указанные скрытые векторы. no_grad диспетчер контекста говорит PyTorch не беспокоиться об отслеживании градиентов здесь, уменьшая объем вычислений.

VanillaGAN.train_step_generator

Эта функция выполняет один шаг обучения Генератора и возвращает потерю как число с плавающей запятой. Наряду с этапом обучения дискриминатора, это суть алгоритма, поэтому давайте рассмотрим его построчно:

self.generator.zero_grad()

Очистите градиенты. Самое крутое в PyTorch - это то, что градиент автоматически накапливается в каждом параметре по мере использования сети. Однако обычно мы хотим очищать эти градиенты между каждым шагом оптимизатора; метод zero_grad делает именно это.

latent_vec = self.noise_fn(self.batch_size)

Пример batch_size скрытых векторов из функции генерации шума. Легкий.

generated = self.generator(latent_vec)

Подайте скрытые векторы в Генератор и получите сгенерированные образцы в качестве вывода (под капотом здесь вызывается метод generator.forward). Помните, что PyTorch - это определение по запуску, так что это точка, где строится вычислительный граф генератора.

classifications = self.discriminator(generated)

Загрузите сгенерированные образцы в Дискриминатор и убедитесь, что каждый образец настоящий. Помните, что Дискриминатор пытается классифицировать эти образцы как поддельные (0), в то время как Генератор пытается обманом заставить его думать, что они настоящие (1). Как и в предыдущей строке, здесь строится вычислительный граф Дискриминатора, и поскольку ему были предоставлены сгенерированные выборки generated в качестве входных данных, этот вычислительный граф застревает на конце вычислительного графа Генератора.

loss = self.criterion(classifications, self.target_ones)

Рассчитайте потери для генератора. Наша функция потерь - это двоичная кросс-энтропия, поэтому потери для каждого из batch_size отсчетов вычисляются и усредняются в одно значение. loss - тензор PyTorch с единственным значением в нем, поэтому он по-прежнему связан с полным вычислительным графом.

loss.backward()

Вот где происходит волшебство. Или, скорее, здесь происходит престиж, поскольку все это время магия происходила незримо. Здесь метод backward вычисляет градиент d_loss / d_x для каждого параметра x в вычислительном графике.

self.optim_g.step()

Примените один шаг оптимизатора, смещая каждый параметр вниз по градиенту. Если вы раньше создавали GAN в Керасе, вы, вероятно, знакомы с необходимостью устанавливать my_network.trainable = False. Одним из преимуществ PyTorch является то, что вам не нужно беспокоиться об этом, потому что optim_g было сказано заботиться только о параметрах нашего генератора.

return loss.item()

Верните потерю. Мы будем хранить их в списке для дальнейшей визуализации. Однако жизненно важно, чтобы мы использовали метод item, чтобы вернуть его как число с плавающей запятой, не как тензор PyTorch. Это связано с тем, что, если мы сохраним ссылку на этот тензорный объект в списке, Python также сохранит весь вычислительный граф. Это большая трата памяти, поэтому нам нужно убедиться, что мы сохраняем только то, что нам нужно (значение), чтобы сборщик мусора Python мог очистить все остальное.

VanillaGAN.train_step_discriminator

Как и предыдущий метод, train_step_discriminator выполняет один шаг обучения дискриминатора. Давайте пройдемся по нему строка за парой:

self.discriminator.zero_grad()

Вы знаете это!

# real samples
real_samples = self.data_fn(self.batch_size)
pred_real = self.discriminator(real_samples)
loss_real = self.criterion(pred_real, self.target_ones)

Возьмите несколько реальных образцов из целевой функции, получите уверенность Дискриминатора в том, что они реальны (Дискриминатор хочет максимизировать это!), И рассчитайте потери. Это очень похоже на шаг обучения генератора.

# generated samples
latent_vec = self.noise_fn(self.batch_size)
with torch.no_grad():
    fake_samples = self.generator(latent_vec)
pred_fake = self.discriminator(fake_samples)
loss_fake = self.criterion(pred_fake, self.target_zeros)

Сделайте выборку из нескольких сгенерированных сэмплов из генератора, получите уверенность Дискриминатора в их реальности (Дискриминатор хочет минимизировать это!) И вычислить потери. Поскольку здесь мы обучаем Дискриминатор, нас не волнуют градиенты в Генераторе, и поэтому мы используем no_grad диспетчер контекста. В качестве альтернативы вы можете отказаться от no_grad и заменить его в строке pred_fake = self.discriminator(fake_samples.detach()) и отделить fake_samples от вычислительного графика Генератора постфактум, но зачем вообще его вычислять?

# combine
loss = (loss_real + loss_fake) / 2

Усредните расчетные графики для реальных образцов и сгенерированных образцов. Да, это действительно так. Это моя любимая строка во всем скрипте, потому что PyTorch может комбинировать обе фазы вычислительного графа, используя простую арифметику Python.

loss.backward()
self.optim_d.step()
return loss_real.item(), loss_fake.item()

Вычислите градиенты, примените один шаг градиентного спуска и верните потери.

VanillaGAN.train_step

Этот метод просто применяет один шаг обучения дискриминатора и один шаг генератора, возвращая потери в виде кортежа.

Собираем все вместе

Добавьте в свой скрипт следующее:

Функция main говорит сама за себя, но давайте рассмотрим ее вместе для полноты картины.

  • Мы импортируем time, потому что обычно рекомендуется проводить обучение нейронной сети.
  • У нас будет 600 эпох по 10 партий в каждой; пакеты и эпохи здесь не нужны, поскольку мы используем истинную функцию вместо набора данных, но давайте придерживаться соглашения для психологического удобства.
  • Мы создаем экземпляр генератора и дискриминатора. Помните, что мы должны указать ширину слоя Дискриминатора.
  • Мы определяем функцию шума как случайные, однородные значения в [0, 1], выраженные как вектор-столбец. Мы указываем устройство как «cpu», но это может быть «CUDA», если у вас это настроено. Обратите внимание: если вы используете здесь cuda, используйте его для целевой функции и VanillaGAN.
  • Мы определяем целевую функцию как случайные, нормальные (0, 1) значения, выраженные как вектор-столбец. Опять же, мы указываем устройство как «cpu».
  • Мы создаем экземпляр VanillaGAN.
  • Мы настроили списки, чтобы отслеживать потери и запустить цикл обучения, выводя статистику обучения после каждой эпохи.

Вот и все! Поздравляем, вы написали свой первый GAN на PyTorch. Я не включил код визуализации, но вот как изученное распределение G выглядит после каждого шага обучения:

А вот потери за эпоху:

Заключительные мысли

Поскольку это руководство было посвящено созданию классов GAN и цикла обучения в PyTorch, о реальной сетевой архитектуре уделялось мало внимания. Современные хаки GAN не использовались, и поэтому окончательное распределение лишь приблизительно напоминает истинное стандартное нормальное распределение. Если вы хотите узнать больше о GAN, попробуйте настроить гиперпараметры и модули; соответствуют ли результаты вашим ожиданиям?

Обычно у нас нет доступа к истинному распределению, генерирующему данные (в противном случае нам не понадобился бы GAN!). В последующем руководстве к этому мы будем реализовывать сверточную GAN, которая использует реальный целевой набор данных вместо функции.

Все не цитируемые изображения являются моими собственными. Не стесняйтесь использовать их, но, пожалуйста, процитируйте эту статью ❤️