Окунитесь в хэллоуинский дух GAN с этим учебным пособием по генератору тыкв

(Генеративное состязательное обучение. Синие линии обозначают поток входных данных, зеленые — выходные данные, а красные — сигналы ошибок.)

Генеративно-состязательные сети, или сокращенно GAN, являются одной из самых захватывающих областей глубокого обучения, появившихся за последние 10 лет. Это, в частности, утверждает Янн ЛеКун из MNIST и известный специалист по обратному распространению. Быстрый прогресс с момента введения GAN в 2014 году Яном Гудфеллоу и другими отмечает противоборствующую подготовку как прорывную идею, в которой есть потенциал изменить общество полезным, гнусным и глупым образом. Обучение GAN использовалось для всего, от предсказуемых кошачьих генераторов до вымышленных портретистов, нарисованных GAN, которые продаются за шестизначные суммы на аукционах изобразительного искусства. Все GAN основаны на простой предпосылке дуэльных сетей: творческая сеть, которая генерирует какие-то выходные данные (в нашем случае изображения), и скептическая сеть, которая выводит вероятность того, что данные реальны или сгенерированы. Они известны как сети Генератор и Дискриминатор, и, просто пытаясь помешать друг другу, они могут научиться генерировать реалистичные данные. В этом руководстве мы построим GAN на основе популярной полностью сверточной архитектуры DCGAN и обучим ее производить тыквы для Хэллоуина.

Мы будем использовать PyTorch, но вы также можете использовать TensorFlow (если вас это устраивает). Опыт использования любой из основных библиотек глубокого обучения стал поразительно схожим, учитывая изменения в TensorFlow в выпуске 2.0 этого года, которые мы видели как часть более широкой конвергенции популярных фреймворков с динамически исполняемым кодом на языке Python с необязательным оптимизированная компиляция графа для ускорения и развертывания.»

Чтобы настроить и активировать виртуальную среду для основных экспериментов PyTorch:

virtualenv pytorch --python=python3 pytorch/bin/pip install numpy matplotlib torch torchvision source pytorch/bin/activate

И если у вас установлен conda и вы предпочитаете его использовать:

conda new -n pytorch numpy matplotlib torch torchvision conda activate pytorch

И чтобы сэкономить вам время на размышления, вот импорт, который нам понадобится:

import random import time import numpy as np import matplotlib.pyplot as plt import torch import torch.nn as nn import torch.nn.parallel import torch.optim as optim import torch.nn.functional as F import torch.utils.data import torchvision.datasets as dset import torchvision.transforms as transforms import torchvision.utils as vutils

Наша GAN будет основана на архитектуре DCGAN и во многом заимствует официальную реализацию в примерах PyTorch. DC в DCGAN означает Deep Convolutional, а архитектура DCGAN расширила протокол неконтролируемого состязательного обучения, описанный в оригинальной документе GAN Яна Гудфеллоу. Это относительно простая и интерпретируемая сетевая архитектура, которая может стать отправной точкой для тестирования более сложных идей.

Архитектура DCGAN, как и все GAN, фактически состоит из двух сетей: дискриминатора и генератора. Важно, чтобы они были равномерно согласованы с точки зрения их подходящей мощности, скорости обучения и т. д., чтобы избежать несоответствия сетей. Обучение GAN общеизвестно нестабильно, и может потребоваться небольшая настройка, чтобы заставить его работать с данной комбинацией архитектуры набора данных. В этом примере DCGAN легко застрять с генератором, выдающим желто-оранжевую тарабарщину в виде шахматной доски, но не сдавайтесь! В целом, я глубоко восхищаюсь авторами подобных прорывов, когда можно легко разочароваться из-за ранних плохих результатов и может потребоваться героическое терпение. С другой стороны, иногда это просто вопрос достаточной подготовки и хорошей идеи, и все получается всего за несколько дополнительных часов работы и вычислений.

Генератор представляет собой набор транспонированных сверточных слоев, которые преобразуют длинное и тонкое многоканальное тензорное скрытое пространство в полноразмерное изображение. Это показано на следующей диаграмме из документа DCGAN:

Полностью сверточный генератор от Radford et al. 2016.

Мы создадим экземпляр как подкласс класса torch.nn.Module. Это гибкий способ реализации и разработки моделей. Вы можете заполнить функцию класса forward, позволяющую включать такие вещи, как пропуск соединений, которые невозможны с простым экземпляром модели torch.nn.Sequential.

class Generator(nn.Module): def __init__(self, ngpu, dim_z, gen_features, num_channels): super(Generator, self).__init__() self.ngpu = ngpu self.block0 = nn.Sequential(\ nn.ConvTranspose2d(dim_z, gen_features*32, 4, 1, 0, bias=False),\ nn.BatchNorm2d(gen_features*32),\ nn.ReLU(True)) self.block1 = nn.Sequential(\ nn.ConvTranspose2d(gen_features*32,gen_features*16, 4, 2, 1, bias=False),\ nn.BatchNorm2d(gen_features*16),\ nn.ReLU(True)) self.block2 = nn.Sequential(\ nn.ConvTranspose2d(gen_features*16,gen_features*8, 4, 2, 1, bias=False),\ nn.BatchNorm2d(gen_features*8),\ nn.ReLU(True)) self.block3 = nn.Sequential(\ nn.ConvTranspose2d(gen_features*8, gen_features*4, 4, 2, 1, bias=False),\ nn.BatchNorm2d(gen_features*4),\ nn.ReLU(True)) self.block5 = nn.Sequential(\ nn.ConvTranspose2d(gen_features*4, num_channels, 4, 2, 1, bias=False))\ def forward(self, z): x = self.block0(z) x = self.block1(x) x = self.block2(x) x = self.block3(x) x = F.tanh(self.block5(x)) return x

— творческая половина нашего дуэта GAN, и большинство людей склонны сосредотачиваться на изученных способностях создавать, казалось бы, новые изображения. На самом деле генератор бесполезен без хорошо согласованного дискриминатора. Архитектура дискриминатора будет знакома тем из вас, кто в прошлом создавал несколько глубоких сверточных классификаторов изображений. В данном случае это бинарный классификатор, пытающийся отличить подделку от настоящей, поэтому мы используем сигмовидную функцию активации на выходе вместо softmax, которую мы использовали бы для задач с несколькими классами. Мы также избавляемся от любых полносвязных слоев, так как они здесь не нужны.

Полностью сверточный бинарный классификатор, подходящий для использования в качестве дискриминатора D(x).

И код:

class Discriminator(nn.Module): def __init__(self, ngpu, gen_features, num_channels): super(Discriminator, self).__init__() self.ngpu = ngpu self.block0 = nn.Sequential(\ nn.Conv2d(num_channels, gen_features, 4, 2, 1, bias=False),\ nn.LeakyReLU(0.2, True)) self.block1 = nn.Sequential(\ nn.Conv2d(gen_features, gen_features, 4, 2, 1, bias=False),\ nn.BatchNorm2d(gen_features),\ nn.LeakyReLU(0.2, True)) self.block2 = nn.Sequential(\ nn.Conv2d(gen_features, gen_features*2, 4, 2, 1, bias=False),\ nn.BatchNorm2d(gen_features*2),\ nn.LeakyReLU(0.2, True)) self.block3 = nn.Sequential(\ nn.Conv2d(gen_features*2, gen_features*4, 4, 2, 1, bias=False),\ nn.BatchNorm2d(gen_features*4),\ nn.LeakyReLU(0.2, True)) self.block_n = nn.Sequential( nn.Conv2d(gen_features*4, 1, 4, 1, 0, bias=False),\ nn.Sigmoid()) def forward(self, imgs): x = self.block0(imgs) x = self.block1(x) x = self.block2(x) x = self.block3(x) x = self.block_n(x) return x

Нам также понадобится несколько вспомогательных функций для создания загрузчика данных и инициализации весов модели в соответствии с рекомендациями в документе DCGAN. Приведенная ниже функция возвращает загрузчик данных PyTorch с небольшим увеличением изображения, просто укажите ему папку, содержащую ваши изображения. Я работаю с относительно небольшой партией бесплатных изображений от Pixabay, поэтому аугментация изображений важна для получения лучшего результата от каждого изображения.

def get_dataloader(root_path): dataset = dset.ImageFolder(root=root_path,\ transform=transforms.Compose([\ transforms.RandomHorizontalFlip(),\ transforms.RandomAffine(degrees=5, translate=(0.05,0.025), scale=(0.95,1.05), shear=0.025),\ transforms.Resize(image_size),\ transforms.CenterCrop(image_size),\ transforms.ToTensor(),\ transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)),\ ])) dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,\ shuffle=True, num_workers=num_workers) return dataloader

И инициализировать веса:

def weights_init(my_model): classname = my_model.__class__.__name__ if classname.find("Conv") != -1: nn.init.normal_(my_model.weight.data, 0.0, 0.02) elif classname.find("BatchNorm") != -1: nn.init.normal_(my_model.weight.data, 1.0, 0.02) nn.init.constant_(my_model.bias.data, 0.0)

Вот и все функции и классы. Все, что осталось сейчас, — это связать все вместе с помощью скриптов (и беспощадной итерации по гиперпараметрам). Рекомендуется сгруппировать гиперпараметры вместе в верхней части скрипта (или передать их с помощью флагов или argparse), чтобы можно было легко изменить значения.

# ensure repeatability my_seed = 13 random.seed(my_seed) torch.manual_seed(my_seed) # parameters describing the input latent space and output images dataroot = "images/pumpkins/jacks" num_workers = 2 image_size = 64 num_channels = 3 dim_z = 64 # hyperparameters batch_size = 128 disc_features = 64 gen_features = 64 disc_lr = 1e-3 gen_lr = 2e-3 beta1 = 0.5 beta2 = 0.999 num_epochs = 5000 save_every = 100 disp_every = 100 # set this variable to 0 for cpu-only training. This model is lightweight enough to train on cpu in a few hours. ngpu = 2

Затем мы создаем модели и загрузчик данных. Я использовал настройку с двумя графическими процессорами, чтобы быстро оценить несколько различных итераций гиперпараметров. В PyTorch тривиально тренироваться на нескольких графических процессорах, заключая ваши модели в класс torch.nn.DataParallel. Не беспокойтесь, если все ваши графические процессоры заняты поиском искусственного общего интеллекта, эта модель достаточно легкая для обучения работе с ЦП за разумное время (несколько часов).

dataloader = get_dataloader(dataroot) device = torch.device("cuda:0" if ngpu > 0 and torch.cuda.is_available() else "cpu") gen_net = Generator(ngpu, dim_z, gen_features, \ num_channels).to(device) disc_net = Discriminator(ngpu, disc_features, num_channels).to(device) # add data parallel here for >= 2 gpus if (device.type == "cuda") and (ngpu > 1): disc_net = nn.DataParallel(disc_net, list(range(ngpu))) gen_net = nn.DataParallel(gen_net, list(range(ngpu))) gen_net.apply(weights_init) disc_net.apply(weights_init)

Сети генератора и дискриминатора обновляются вместе в одном большом цикле. Прежде чем мы перейдем к этому, нам нужно определить наш критерий потерь (бинарная перекрестная энтропия), определить оптимизаторы для каждой сети и создать несколько списков, которые мы будем использовать для отслеживания прогресса обучения.

criterion = nn.BCELoss() # a set sample from latent space so we can unambiguously monitor training progress fixed_noise = torch.randn(64, dim_z, 1, 1, device=device) real_label = 1 fake_label = 0 disc_optimizer = optim.Adam(disc_net.parameters(), lr=disc_lr, betas=(beta1, beta2)) gen_optimizer = optim.Adam(gen_net.parameters(), lr=gen_lr, betas=(beta1, beta2)) img_list = [] gen_losses = [] disc_losses = [] iters = 0

Цикл обучения

Цикл обучения концептуально прост, но слишком длинный, чтобы его можно было охватить одним фрагментом, поэтому мы разобьем его на несколько частей. Вообще говоря, мы сначала обновляем дискриминатор на основе прогнозов для набора реальных и сгенерированных изображений. Затем мы передаем сгенерированные изображения вновь обновленному дискриминатору и используем выходные данные классификации из D(G(z)) в качестве обучающего сигнала для генератора, используя реальную метку как цель.

Сначала мы войдем в цикл и выполним обновление дискриминатора:

t0 = time.time() for epoch in range(num_epochs): for ii, data in enumerate(dataloader,0): # update the discriminator disc_net.zero_grad() # discriminator pass with real images real_cpu = data[0].to(device) batch_size= real_cpu.size(0) label = torch.full((batch_size,), real_label, device=device) output = disc_net(real_cpu).view(-1) disc_real_loss = criterion(output,label) disc_real_loss.backward() disc_x = output.mean().item() # discriminator pass with fake images noise = torch.randn(batch_size, dim_z, 1, 1, device=device) fake = gen_net(noise) label.fill_(fake_label) output = disc_net(fake.detach()).view(-1) disc_fake_loss = criterion(output, label) disc_fake_loss.backward() disc_gen_z1 = output.mean().item() disc_loss = disc_real_loss + disc_fake_loss disc_optimizer.step()

Обратите внимание, что мы также отслеживаем средние прогнозы для поддельных и реальных партий. Это даст нам простой способ отслеживать, насколько сбалансировано (или нет) наше обучение, сообщая нам, как прогнозы меняются после каждого обновления.

Далее мы обновим генератор на основе предсказаний дискриминатора, используя реальную метку и бинарную кросс-энтропийную потерю. Обратите внимание, что мы обновляем генератор на основе ошибочной классификации дискриминатором поддельных изображений как настоящих. Этот сигнал создает лучшие градиенты для обучения, чем минимизация способности дискриминатора напрямую обнаруживать подделки. Впечатляет то, что GAN в конечном итоге могут научиться создавать фотореалистичный контент на основе этого типа дуэльных сигналов о потерях.

# update the generator gen_net.zero_grad() label.fill_(real_label) output = disc_net(fake).view(-1) gen_loss = criterion(output, label) gen_loss.backward() disc_gen_z2 = output.mean().item() gen_optimizer.step()

Наконец, есть небольшая уборка, чтобы следить за нашими тренировками. Балансировка обучения GAN — это своего рода искусство, и из одних только цифр не всегда очевидно, насколько эффективно обучаются ваши сети, поэтому рекомендуется время от времени проверять качество изображения. С другой стороны, если какое-либо из значений в операторе печати становится либо 0,0, либо 1,0, скорее всего, ваше обучение рухнуло, и было бы неплохо повторить итерацию с новыми гиперпараметрами.

if ii % disp_every == 0: # discriminator pass with fake images, after updating G(z) noise = torch.randn(batch_size, dim_z, 1, 1, device=device) fake = gen_net(noise) output = disc_net(fake).view(-1) disc_gen_z3 = output.mean().item() print("{} {:.3f} s |Epoch {}/{}:\tdisc_loss: {:.3e}\tgen_loss: {:.3e}\tdisc(x): {:.3e}\tdisc(gen(z)): {:.3e}/{:.3e}/{:.3e}".format(iters,time.time()-t0, epoch, num_epochs, disc_loss.item(), gen_loss.item(), disc_x, disc_gen_z1, disc_gen_z2, disc_gen_z3)) disc_losses.append(disc_loss.item()) gen_losses.append(gen_loss.item()) if (iters % save_every == 0) or \ ((epoch == num_epochs-1) and (ii == len(dataloader)-1)): with torch.no_grad(): fake = gen_net(fixed_noise).detach().cpu() img_list.append(vutils.make_grid(fake, padding=2, normalize=True).numpy()) np.save("./gen_images.npy", img_list) np.save("./gen_losses.npy", gen_losses) np.save("./disc_losses.npy", disc_losses) torch.save(gen_net.state_dict(), "./generator.h5") torch.save(disc_net.state_dict(), "./discriminator.h5") iters += 1

Может потребоваться немного усилий, чтобы получить сносные результаты, но, к счастью для нас, небольшие сбои в реальности на самом деле предпочтительнее для усиления жуткости. Описанный здесь код можно улучшить, но он должен подойти для достаточно правдоподобных фонарей из тыкв с низким разрешением.

Прогресс обучения примерно через 5000 эпох обновлений.

Надеемся, что приведенный выше учебник послужил для того, чтобы подогреть ваш аппетит к GAN, поделкам на Хэллоуин или к тому и другому. После освоения базовой DCGAN, которую мы построили здесь, поэкспериментируйте с более сложными архитектурами и приложениями. Обучение GAN по-прежнему остается искусной наукой, и сбалансировать обучение сложно. Используйте подсказки из ганхаков и, получив упрощенное доказательство концепции, работающее для вашего набора данных/приложения/идеи, добавляйте за раз только небольшие кусочки сложности. Удачи и хороших тренировок.

Первоначально опубликовано на https://blog.exxactcorp.com 28 октября 2019 г.