Введение в cycleGAN с использованием pytorch и перевода изображений с заменой пола
Краткая история…
GAN или генеративно-состязательная сеть была представлена в рамках исследовательской работы Яна Гудфеллоу в 2014 году. В этой статье он первоначально предложил генерировать новые данные из существующего набора данных с использованием конкурирующих нейронных сетей. В 2017 году, основываясь на этом фундаменте, другая группа или исследователи (Jun-Yan Zhu, Taesung Park, Phillip Isola, Alexei A. Efros) использовали этот метод для создания собственного цикла GAN для преобразования изображений. . Один из наиболее классических примеров — превращение лошадей в зебр и наоборот. Из этих работ предмет GAN продолжал развиваться с различными подходами и стилями использования GAN.
Ваш первый GAN
Еще до того, как использовать GAN для преобразования изображения в изображение, он использовался в других областях для создания новых данных с помощью существующих данных. Эти наборы данных будут совершенно разными, но близкими друг к другу, где они могут или не могут быть отличимы друг от друга.
Случаи использования
Несмотря на популяризацию этой техники с изображениями, существуют также приложения, не связанные с изображениями. Это можно использовать для обнаружения новых белков при поиске новой вакцины. Правоохранительные органы также могут использовать эту технологию для выявления подозреваемых и
Цель этой статьи
Предполагая, что вы видели GAN и cycleGAN в нескольких источниках, я попытаюсь провести вас через процесс создания cycleGAN с минимальным объяснением, но вместо этого покажу вам, как его создать. Мы не хотим усложнять процесс, чтобы доктора и не доктора наук могли создать свою собственную GAN.
Теперь мы начинаем…
Наряду с созданием нескольких функций и классов, ниже приведены некоторые требования, необходимые для создания цикла GAN.
- Библиотеки Python
2. Набор данных изображения
3. CUDA — для использования графического процессора
4. Загрузчик данных
5. Трансформатор
6. Функция весов
7. Resnet или остаточная обучающая сеть
8. 4 искусственные нейронные сети (2 генератора, 2 дискриминатора)
9. Функция потерь
10. Оптимизаторы
11. Автоэнкодер
Библиотеки Python:
import os, random, cv2, glob, itertools, shutil import numpy as np import pandas as pd import matplotlib.pyplot as plt import matplotlib.image as mpimg from PIL import Image import torch import torch.nn as nn import torch.optim as optim import torch.utils.data as data from torch.autograd import Variable import torchvision from torchvision import datasets, transforms, utils from tqdm.notebook import tqdm import warnings warnings.filterwarnings('ignore')
Набор данных изображения
Я использовал набор изображений UTKFace. Этот набор данных был предоставлен нам в открытый доступ Ян Сонгом и Чжифэй Чжаном.
Помечено около 65 000 изображений. Помимо изображений, также были предоставлены файлы ориентиров. Для целей этого проекта мы не будем использовать какие-либо файлы ориентиров/указатели.
Просмотр изображений
для я в диапазоне (5):
файл = np.random.choice(os.listdir(UTKFace_folder))
image_path= os.path.join (UTKFace_folder, файл)
изображение = mpimg.imread (путь_изображения)
plt.subplot(1,10,i+1)
plt.imshow(изображение)
Изменение размера изображений
Исходные изображения имеют размер 200x200. Мы можем сначала изменить размер этих изображений до 128x128 для более быстрого обучения. Это также можно сделать в разделе преобразования, но для упрощения мы изменим их размер здесь. Приведенный ниже код изменит размер изображений по тому же пути, по которому расположены текущие изображения. Нет необходимости создавать еще одну папку, кроме как в целях резервного копирования.
#Gathering all the images and combining them to one all_im = glob(os.path.join(basepath, '**/*')) for im in tqdm(all_im): image = Image.open(im) image = image.resize((128,128), Image.BILINEAR) image.save(im)
Подготовка комплекта для обучения и тестирования.
Нам нужно будет создать 4 набора изображений. Обучающий набор и тестовый набор исходных изображений, в данном случае мужских изображений. Вторым набором будут изображения, в которые мы хотим их преобразовать, также в этом случае набор данных «Женские изображения».
При подготовке изображений одним из наиболее важных факторов является обеспечение того, чтобы количество изображений было одинаковым для каждого набора для обучения и набора для тестирования. В этом случае количество для каждого набора следующее:
TrainA_Male: 7 549 изображений
TestA_Male: 463 изображения
TrainB_Female:7, 549 изображений
TestB_Female: 463 изображения
Чтобы ускорить обучение и посмотреть, работает ли эта модель, я принял во внимание только изображения с возрастом от 20 до 70 лет. Все более молодые или старые изображения были удалены.
Параметры:
Глобальные параметры также созданы для стандартизации и облегчения запоминания того, какие назначения были сделаны. Следите за переменными, которые вы назначаете здесь, так как это может быть источником некоторых проблем, с которыми вы можете столкнуться.
parameters = {'ngf':32, 'ndf':64, 'num_epochs':100, 'decay_epoch': 100, 'lgG':0.0002, 'lgD':0.0002, 'beta1':.5, 'beta2':.9999, 'lambdaA':10, 'lambdaB':10, 'batch_size':32, 'pool_size':50, 'img_width':256, 'img_height':256, 'img_depth':3, 'input_size':128, 'resize_size':128, 'crop_size':256, 'fliplr':True, 'num_resnet':6, 'num_workers':2}
Использование графического процессора
Поскольку мы также хотим использовать GPU для более быстрого процесса, мы будем использовать CUDA со следующим кодом.
def to_numpy(x): return x.data.cpu().numpy() #use GPU device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Трансформер
Преобразователь преобразует изображения в тензор и нормализует их. Обратите внимание, что мы сначала преобразуем изображение в тензор, прежде чем нормализовать его.
def to_numpy(x): return x.data.cpu().numpy() #use GPU device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
Преобразователь также является отличным инструментом, который мы можем использовать для добавления дополнительных данных. Нейронные сети хорошо работают с большим количеством данных. Вы можете нажать здесь для получения дополнительной информации о трансформерах.
Создание функции пула изображений
Функция пула изображений будет использоваться при извлечении изображений из пула поддельных изображений на более позднем этапе.
class ImagePool(): def __init__(self, pool_size): self.pool_size = pool_size if self.pool_size > 0: self.num_imgs = 0 self.images = [] def query(self, images): if self.pool_size == 0: return images return_images = [] for image in images.data: image = torch.unsqueeze(image, 0) if self.num_imgs < self.pool_size: self.num_imgs = self.num_imgs + 1 self.images.append(image) return_images.append(image) else: p = random.uniform(0, 1) if p > 0.5: random_id = random.randint(0, self.pool_size-1) tmp = self.images[random_id].clone() self.images[random_id] = image return_images.append(tmp) else: return_images.append(image) return_images = Variable(torch.cat(return_images, 0)) return return_images
Нанесение изображений
Чтобы получить изображения, нам также понадобится функция для построения результатов.
def plot_train_result(real_image, gen_image, recon_image, epoch, save=False, show=True, fig_size=(15, 15)): fig, axes = plt.subplots(2, 3, figsize=fig_size) imgs = [to_numpy(real_image[0]), to_numpy(gen_image[0]), to_numpy(recon_image[0]), to_numpy(real_image[1]), to_numpy(gen_image[1]), to_numpy(recon_image[1])] for ax, img in zip(axes.flatten(), imgs): ax.axis('off') img = img.squeeze() img = (((img - img.min()) * 255) / (img.max() - img.min())).transpose(1, 2, 0).astype(np.uint8) ax.imshow(img, cmap=None, aspect='equal') plt.subplots_adjust(wspace=0, hspace=0) title = 'Epoch {0}'.format(epoch + 1) fig.text(0.5, 0.04, title, ha='center')
Загрузчик данных
Поскольку мы не сможем загрузить весь набор изображений, так как это может занять огромное количество времени, мы будем использовать загрузчик данных для загрузки изображений наборами. Это более эффективный способ загрузки наборов данных, поскольку наблюдается значительное увеличение объема данных.
class DatasetFromFolder(data.Dataset): def __init__(self, image_dir, subfolder='train', transform=None, resize_scale=None, crop_size=None, fliplr=False): super(DatasetFromFolder, self).__init__() self.input_path = os.path.join(image_dir, subfolder) self.image_filenames = [x for x in sorted(os.listdir(self.input_path))] self.file_num = len(self.image_filenames) self.transform = trans self.resize_scale = resize_scale self.crop_size = crop_size self.fliplr = fliplr def __len__(self): return len(self.image_filenames) def __getitem__(self, index): # Load Images img_fn = os.path.join(self.input_path, self.image_filenames[index]) img = Image.open(img_fn).convert('RGB') #preprocessing if self.resize_scale: img = img.resize((self.resize_scale, self.resize_scale), Image.BILINEAR) if self.crop_size: x = random.randint(0, self.resize_scale - self.crop_size + 1) y = random.randint(0, self.resize_scale - self.crop_size + 1) img = img.crop((x, y, x + self.crop_size, y + self.crop_size)) if self.fliplr: if random.random() < 0.5: img = img.transpose(Image.FLIP_LEFT_RIGHT) if self.transform is not None: img = self.transform(img) return img
Загрузка DataLoader из папки
class DatasetFromFolder(data.Dataset): def __init__(self, image_dir, subfolder='train', transform=None, resize_scale=None, crop_size=None, fliplr=False): super(DatasetFromFolder, self).__init__() self.input_path = os.path.join(image_dir, subfolder) self.image_filenames = [x for x in sorted(os.listdir(self.input_path))] self.file_num = len(self.image_filenames) self.transform = trans self.resize_scale = resize_scale self.crop_size = crop_size self.fliplr = fliplr def __len__(self): return len(self.image_filenames) def __getitem__(self, index): # index = total_index % self.file_num # Load Images img_fn = os.path.join(self.input_path, self.image_filenames[index]) img = Image.open(img_fn).convert('RGB') #preprocessing if self.resize_scale: img = img.resize((self.resize_scale, self.resize_scale), Image.BILINEAR) if self.crop_size: x = random.randint(0, self.resize_scale - self.crop_size + 1) y = random.randint(0, self.resize_scale - self.crop_size + 1) img = img.crop((x, y, x + self.crop_size, y + self.crop_size)) if self.fliplr: if random.random() < 0.5: img = img.transpose(Image.FLIP_LEFT_RIGHT) if self.transform is not None: img = self.transform(img) return img
Сверточная сеть
Ресурсы CNN:
https://en.wikipedia.org/wiki/Сверточная_нейронная_сеть
class DatasetFromFolder(data.Dataset): def __init__(self, image_dir, subfolder='train', transform=None, resize_scale=None, crop_size=None, fliplr=False): super(DatasetFromFolder, self).__init__() self.input_path = os.path.join(image_dir, subfolder) self.image_filenames = [x for x in sorted(os.listdir(self.input_path))] self.file_num = len(self.image_filenames) self.transform = trans self.resize_scale = resize_scale self.crop_size = crop_size self.fliplr = fliplr def __len__(self): return len(self.image_filenames) def __getitem__(self, index): # index = total_index % self.file_num # Load Images img_fn = os.path.join(self.input_path, self.image_filenames[index]) img = Image.open(img_fn).convert('RGB') #preprocessing if self.resize_scale: img = img.resize((self.resize_scale, self.resize_scale), Image.BILINEAR) if self.crop_size: x = random.randint(0, self.resize_scale - self.crop_size + 1) y = random.randint(0, self.resize_scale - self.crop_size + 1) img = img.crop((x, y, x + self.crop_size, y + self.crop_size)) if self.fliplr: if random.random() < 0.5: img = img.transpose(Image.FLIP_LEFT_RIGHT) if self.transform is not None: img = self.transform(img) return img
Деконволюционная сеть
ресурсы деконволюции:
https://datascience.stackexchange.com/questions/6107/what-are-deconvolutional-layers
class DeconvBlock(torch.nn.Module): #Initialization Function def __init__(self, input_size, output_size, kernel_size=3, stride=2, padding=1, output_padding=1, activation='relu', batch_norm=True): super(DeconvBlock,self).__init__() self.deconv = torch.nn.ConvTranspose2d(input_size, output_size, kernel_size, stride, padding, output_padding) self.batch_norm = batch_norm self.bn = torch.nn.InstanceNorm2d(output_size) self.activation = activation self.relu = torch.nn.ReLU(True) #Forward pass def forward(self,x): if self.batch_norm: out = self.bn(self.deconv(x)) else: out = self.deconv(x) if self.activation == 'relu': return self.relu(out) elif self.activation == 'lrelu': return self.lrelu(out) elif self.activation == 'tanh': return self.tanh(out) elif self.activation == 'no_act': return out
Resnet или остаточный обучающий блок/сеть
Ресурсы Ренета:
https://towardsdatascience.com/an-overview-of-resnet-and-its-variants-5281e2f56035
class ResnetBlock(torch.nn.Module): def __init__(self,num_filter,kernel_size=3,stride=1,padding=0): super(ResnetBlock,self).__init__() conv1 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding) conv2 = torch.nn.Conv2d(num_filter, num_filter, kernel_size, stride, padding) bn = torch.nn.InstanceNorm2d(num_filter) relu = torch.nn.ReLU(True) pad = torch.nn.ReflectionPad2d(1) self.resnet_block = torch.nn.Sequential( pad, conv1, bn, relu, pad, conv2, bn ) def forward(self,x): out = self.resnet_block(x) return outThe Generator
Генератор
Теперь самое интересное: теперь мы создаем генератор, который будем использовать в учебном блоке позже.
class Generator(torch.nn.Module): def __init__(self,input_dim,num_filter,output_dim,num_resnet): super(Generator,self).__init__() #Reflection padding self.pad = torch.nn.ReflectionPad2d(3) #Encoder # Input Layer - NN Layer 1 self.conv1 = ConvBlock(input_dim, num_filter, kernel_size=7, stride=1, padding=0) # NN Layer 2 self.conv2 = ConvBlock(num_filter, num_filter*2) ## NN Layer 3 self.conv3 = ConvBlock(num_filter*2, num_filter*4) #Transformer self.resnet_blocks = [] for i in range(num_resnet): self.resnet_blocks.append(ResnetBlock(num_filter*4)) self.resnet_blocks = torch.nn.Sequential(*self.resnet_blocks) #Decoder self.deconv1 = DeconvBlock(num_filter*4, num_filter*2) self.deconv2 = DeconvBlock(num_filter*2, num_filter) self.deconv3 = ConvBlock(num_filter, output_dim, kernel_size=7, stride=1, padding=0, activation='tanh', batch_norm=False) #Forward Function def forward(self,x): #Encoder enc1 = self.conv1(self.pad(x)) enc2 = self.conv2(enc1) enc3 = self.conv3(enc2) #Transformer res = self.resnet_blocks(enc3) #Decoder dec1 = self.deconv1(res) dec2 = self.deconv2(dec1) out = self.deconv3(self.pad(dec2)) return out #Creating weights inside the Generator def normal_weight_init(self,mean=0.0,std=0.02): for m in self.children(): if isinstance(m,ConvBlock): torch.nn.init.normal_(m.conv.weight,mean,std) if isinstance(m,DeconvBlock): torch.nn.init.normal_(m.deconv.weight,mean,std) if isinstance(m,ResnetBlock): torch.nn.init.normal_(m.conv.weight,mean,std) torch.nn.init.constant_(m.conv.bias,0)
Генератор А
Generator_A = Generator(3, parameters['ngf'], 3, parameters['num_resnet']).cuda() Generator_A.normal_weight_init(mean=0.0, std=0.02)
Генератор Б
Generator_A = Generator(3, parameters['ngf'], 3, parameters['num_resnet']).cuda() Generator_A.normal_weight_init(mean=0.0, std=0.02)
Generator_B = Generator(3, параметры[‘ngf’], 3, параметры[‘num_resnet’]).cuda() Generator_B.normal_weight_init(mean=0.0, std=0.02) Generator_B
Создание дискриминатора
class Discriminator(torch.nn.Module): def __init__(self,input_dim,num_filter,output_dim): super(Discriminator,self).__init__() #Input - NN Layer 1 conv1 = ConvBlock(input_dim,num_filter,kernel_size=4,stride=2,padding=1,activation='lrelu',batch_norm=False) #NN Layer 2 conv2 = ConvBlock(num_filter,num_filter*2,kernel_size=4,stride=2,padding=1,activation='lrelu') #NN Layer 3 conv3 = ConvBlock(num_filter*2,num_filter*4,kernel_size=4,stride=2,padding=1,activation='lrelu') #NN Layer 4 conv4 = ConvBlock(num_filter*4,num_filter*8,kernel_size=4,stride=1,padding=1,activation='lrelu') #Output - NN Layer 5 conv5 = ConvBlock(num_filter*8,output_dim,kernel_size=4,stride=1,padding=1,activation='no_act',batch_norm=False) self.conv_blocks = torch.nn.Sequential( conv1, conv2, conv3, conv4, conv5 ) def forward(self,x): out = self.conv_blocks(x) return out #Creating the weights for the Discriminator def normal_weight_init(self, mean=0.0, std=0.02): for m in self.children(): if isinstance(m, ConvBlock): torch.nn.init.normal_(m.conv.weight.data, mean, std)
Дискриминатор А
Discriminator_A = Discriminator(3, parameters['ndf'], 1).cuda() Discriminator_A.normal_weight_init(mean=0.0, std=0.02)
Дискриминатор Б
Discriminator_B = Discriminator(3, parameters['ndf'], 1).cuda() Discriminator_B.normal_weight_init(mean=0.0, std=0.02)
Обучающий блок/загрузчики данных
train_data_A = DatasetFromFolder(basepath, subfolder='trainA_Male', transform=trans, resize_scale=parameters['resize_size'], #crop_size=parameters['crop_size'], #fliplr=parameters['fliplr'] ) train_data_loader_A = torch.utils.data.DataLoader(dataset=train_data_A, batch_size=parameters['batch_size'], num_workers=parameters['num_workers'], pin_memory=True, shuffle=True) train_data_B = DatasetFromFolder(basepath, subfolder='trainB_Female', transform=trans, resize_scale=parameters['resize_size'], #crop_size=parameters['crop_size'], #fliplr=True ) train_data_loader_B = torch.utils.data.DataLoader(dataset=train_data_B, batch_size=parameters['batch_size'], num_workers=parameters['num_workers'], pin_memory=True, shuffle=True) #Load test data test_data_A = DatasetFromFolder(basepath, subfolder='testA_Male', transform=trans) test_data_loader_A = torch.utils.data.DataLoader(dataset=test_data_A, batch_size=parameters['batch_size'], shuffle=False) test_data_B = DatasetFromFolder(basepath, subfolder='trainB_Female', transform=trans) test_data_loader_B = torch.utils.data.DataLoader(dataset=test_data_B, batch_size=parameters['batch_size'], shuffle=False)
Проверка реальных данных A и B
test_real_A_data = train_data_A.__getitem__(11).unsqueeze(0)
test_real_B_data = train_data_B.__getitem__(11).unsqueeze(0)
Создание оптимизатора
В нашем случае мы будем использовать один из самых популярных и простых оптимизаторов — оптимизатор ADAM.
ресурс:
Generator_optimizer = torch.optim.Adam(itertools.chain(Generator_A.parameters(), Generator_B.parameters()), betas=(parameters['beta1'], parameters['beta2']), lr=parameters['lgG']) Discriminator_A_optimizer = torch.optim.Adam(itertools.chain(Discriminator_A.parameters(), Discriminator_B.parameters()), betas=(parameters['beta1'], parameters['beta2']), lr=parameters['lgD']) Discriminator_B_optimizer = torch.optim.Adam(itertools.chain(Discriminator_A.parameters(), Discriminator_B.parameters()), betas=(parameters['beta1'], parameters['beta2']), lr=parameters['lgD'])
Определение потерь
MSE_Loss = torch.nn.MSELoss().cuda() L1_Loss = torch.nn.L1Loss().cuda()
Обучение модели и получение изображений
Discriminator_A_avg_losses = [] Discriminator_B_avg_losses = [] Generator_A_avg_losses = [] Generator_B_avg_losses = [] cycle_A_avg_losses = [] cycle_B_avg_losses = [] for epoch in range(parameters['num_epochs']): Discriminator_A_losses = [] Discriminator_B_losses = [] Generator_A_losses = [] Generator_B_losses = [] cycle_A_losses = [] cycle_B_losses = [] # Learing rate decay if(epoch + 1) > parameters['decay_epoch']: Discriminatoroptimizer.param_groups[0]['lr'] -= parameters['lgD'] / (parameters['num_epochs'] - parameters['decay_epoch']) Discriminator_B_optimizer.param_groups[0]['lr'] -= parameters['lgD'] / (parameters['num_epochs'] - parameters['decay_epoch']) G_optimizer.param_groups[0]['lr'] -= parameters['lrG'] / (parameters['num_epochs'] - parameters['decay_epoch']) # training for i, (real_A, real_B) in tqdm(enumerate(zip(train_data_loader_A, train_data_loader_B)), total=len(train_data_loader_A)): # input image data real_A = real_A.to(device) real_B = real_B.to(device) # Train The Generator # A --> B fake_B = Generator_A(real_A) Discriminator_B_fake_decision = Discriminator_B(fake_B) Generator_A_loss = MSE_Loss(Discriminator_B_fake_decision, Variable(torch.ones(Discriminator_B_fake_decision.size()).cuda())) # Forward Cycle Loss recon_A = Generator_B(fake_B) cycle_A_loss = L1_Loss(recon_A, real_A) * parameters['lambdaA'] # B --> A fake_A = Generator_B(real_B) Discriminator_A_fake_decision = Discriminator_A(fake_A) Generator_B_loss = MSE_Loss(Discriminator_A_fake_decision, Variable(torch.ones(Discriminator_A_fake_decision.size()).cuda())) # Backward Cycle Loss recon_B = Generator_A(fake_A) cycle_B_loss = L1_Loss(recon_B, real_B) * parameters['lambdaB'] # Back Propagation Generator_loss = Generator_A_loss + Generator_B_loss + cycle_A_loss + cycle_B_loss Generator_optimizer.zero_grad() Generator_loss.backward() Generator_optimizer.step() # Train Discriminator_A Discriminator_A_real_decision = Discriminator_A(real_A) Discriminator_A_real_loss = MSE_Loss(Discriminator_A_real_decision, Variable(torch.ones(Discriminator_A_real_decision.size()).cuda())) fake_A = fake_A_Pool.query(fake_A) Discriminator_A_fake_decision = Discriminator_A(fake_A) Discriminator_A_fake_loss = MSE_Loss(Discriminator_A_fake_decision, Variable(torch.zeros(Discriminator_A_fake_decision.size()).cuda())) # Back propagation Discriminator_A_loss = (Discriminator_A_real_loss + Discriminator_A_fake_loss) * 0.5 Discriminator_A_optimizer.zero_grad() Discriminator_A_loss.backward() Discriminator_A_optimizer.step() # Train Discriminator_B Discriminator_B_real_decision = Discriminator_B(real_B) # print('real_A, ',real_A.shape) # print('real_B, ',real_B.shape) # print('dis_A, ',Discriminator_A_real_decision.shape) # print('dis_B, ',Discriminator_B_real_decision.shape) Discriminator_B_real_loss = MSE_Loss(Discriminator_B_real_decision, Variable(torch.ones(Discriminator_B_fake_decision.size()).cuda())) #Discriminator_B_real_loss = MSE_Loss(Discriminator_B_real_decision, Variable(torch.ones(Discriminator_B_real_decision.size()).cuda())) fake_B = fake_B_Pool.query(fake_B) Discriminator_B_fake_decision = Discriminator_B(fake_B) Discriminator_B_fake_loss = MSE_Loss(Discriminator_B_fake_decision, Variable(torch.zeros(Discriminator_B_fake_decision.size()).cuda())) # Back Propagation Discriminator_B_loss = (Discriminator_B_real_loss + Discriminator_B_fake_loss) * 0.5 Discriminator_B_optimizer.zero_grad() Discriminator_B_loss.backward() Discriminator_B_optimizer.step() # Print # loss values Discriminator_A_losses.append(Discriminator_A_loss.item()) Discriminator_B_losses.append(Discriminator_B_loss.item()) Generator_A_losses.append(Generator_A_loss.item()) Generator_B_losses.append(Generator_B_loss.item()) cycle_A_losses.append(cycle_A_loss.item()) cycle_B_losses.append(cycle_B_loss.item()) if i%100 == 0: print('Epoch [%d/%d], Step [%d/%d], Discriminator_A_losses: %.4f, Discriminator_B_loss: %.4f, Generator_A_loss: %.4f, Generator_B_loss: %.4f' % (epoch+1, parameters['num_epochs'], i+1, len(train_data_loader_A), Discriminator_A_loss.item(), Discriminator_B_loss.item(), Generator_A_loss.item(), Generator_B_loss.item())) step += 1 Discriminator_A_avg_loss = torch.mean(torch.FloatTensor(Discriminator_A_losses)) Discriminator_B_avg_loss = torch.mean(torch.FloatTensor(Discriminator_B_losses)) Generator_A_avg_loss = torch.mean(torch.FloatTensor(Generator_A_losses)) Generator_B_avg_loss = torch.mean(torch.FloatTensor(Generator_B_losses)) cycle_A_avg_loss = torch.mean(torch.FloatTensor(cycle_A_losses)) cycle_B_avg_loss = torch.mean(torch.FloatTensor(cycle_B_losses)) # Average Loss Values Discriminator_A_avg_losses.append(Discriminator_A_avg_loss.item()) Discriminator_B_avg_losses.append(Discriminator_B_avg_loss.item()) Generator_A_avg_losses.append(Generator_A_avg_loss.item()) Generator_B_avg_losses.append(Generator_B_avg_loss.item()) cycle_A_avg_losses.append(cycle_A_avg_loss.item()) cycle_B_avg_losses.append(cycle_B_avg_loss.item()) # Test Image Results test_real_A = test_real_A_data.cuda() test_fake_B = Generator_A(test_real_A) test_recon_A = Generator_B(test_fake_B) test_real_B = test_real_B_data.cuda() test_fake_A = Generator_B(test_real_B) test_recon_B = Generator_A(test_fake_A) plot_train_result([test_real_A, test_real_B], [test_fake_B, test_fake_A], [test_recon_A, test_recon_B], epoch, save=True)
Результаты в разные эпохи
Восстановление убытков
all_losses = pd.DataFrame() all_losses['Discriminator_A_avg_losses'] = Discriminator_A_avg_losses all_losses['Discriminator_B_avg_losses'] = Discriminator_B_avg_losses all_losses['Generator_A_avg_losses'] = Generator_A_avg_losses all_losses['Generator_B_avg_losses'] = Generator_B_avg_losses all_losses['cycle_A_avg_losses'] = cycle_A_avg_losses all_losses['cycle_B_avg_losses'] = cycle_B_avg_losses
График потерь
plt.figure(figsize=(20,20)) losses.plot() plt.legend(bbox_to_anchor=(1, 1), loc='upper left', ncol=1) plt.show()
Вывод:
Хотя это и не идеально, мы можем видеть, что модель работает, визуально глядя на изменения в изображениях. Кроме того, мы также можем видеть уменьшение потерь по мере увеличения эпох. С дополнительными ресурсами мы также можем попытаться внести изменения в эпохи и скорость обучения, чтобы увидеть, какие значения хорошо работают для улучшения результатов.
Я надеюсь, что смог помочь в создании одной из многих GAN, которые вы намереваетесь создать.
Использованная литература:
https://github.com/ltq477/CycleGAN/blob/master/CycleGAN%20-%20Facial%20Gans%20ver_IV.ipynb
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
Учебник по Питорч:
https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
Пример реализации из учебника: