WGAN е вид мрежа, използвана за генериране на фалшиви висококачествени изображения от входен вектор. В този експеримент внедрих две различни подобрения на WGAN в Pytorch, за да видя кой е в състояние да се представи най-добре по отношение на скоростта и качеството на генерираните изображения. (Github: https://github.com/BradleyBrown19/WGANOptimizations).

Какво е WGAN?

Предполагам, че имам предварителни познания за GAN, така че препоръчвам да прочетете следното, ако не сте запознати: https://towardsdatascience.com/image-generator-drawing-cartoons-with-generative-adversarial-networks-45e814ca9b6b

Във формулирането на WGAN има много сложна математика, но ще се опитам да я запазя проста и да я обясня интуитивно.

Основната концепция зад WGAN е да се опита да сведе до минимум разстоянието на земехода. Разстоянието на Earth Mover (уравнение 1) получава името си, защото може да се приеме за минималната работа, необходима за превръщането на една купчина пръст в друга купчина. С други думи, това е количеството мръсотия, умножено по изминатото разстояние.

В нашия случай купчините мръсотия са вероятностните разпределения на реалните изображения, Pr, и генерираните изображения, Pg.

Сега нека разбием горното уравнение. Π(pr,pg) означава вземане на всички съвместни разпределения между Pr и Pg и намира най-голямата долна граница (inf) на преместване на „мръсотията“ на разстояние между x и y („x−y‖), умножено по общото количество мръсотия Eγ(x,y). С други думи, ние се опитваме да намерим най-малкото количество работа или разходи, за да трансформираме вероятностните разпределения на генерираните изображения в реалните.

Това беше хапка.

Използвайки нещо, наречено двойственост на Канторович-Рубинщайн (вижте приложението към: https://arxiv.org/pdf/1701.07875.pdf), уравнение 1 може да се трансформира в:

където f е 1-липшиц, който трябва да следва ограничението:

Какво е f? Как да го намерим, за да максимизираме уравнение 2?

Drumroll моля... използваме невронна мрежа, т.е. критикът. Критикът е много подобен на дискриминатора в нормален GAN, но вместо да извежда вероятност изображението да е реално или фалшиво, той извежда скалар, представящ колко „реално“ е изображението. Единствената промяна в архитектурата между двете е, че новият критик пропуска сигмоидния слой, защото вече не е необходим. Критикът не може да покрие всички възможности на f за намиране на върховната сума на уравнения 2, но тъй като приближителната способност на невронната мрежа е огромна, тя осигурява много добра оценка.

Друга промяна с WGAN е, че функциите на журнала в загубите вече не са необходими.

Сега, когато имаме всичко, от което се нуждаем, е настроено, всичко, което остава, е да удовлетворим ограничението 1-Lipschitz. В нормален WGAN това се прави чрез просто изрязване на всички тегла в дискриминатора, като буквално се уверява, че са под абсолютна стойност. Тази стойност обикновено е 0,01.

Внедряване 1: WGAN с градиентно наказание и срок за съгласуваност

Какво е Gradient Penalty?

Големият недостатък на подстригването на теглото е, че то ограничава ефективността на модела. Ако теглата са ограничени твърде много, тогава моделът не може да се научи да моделира сложни функции и като такова приближението на f не е оптимално. От друга страна, ако тежестите не са достатъчно ограничени, това води до изчезващи градиенти. Като цяло оригиналният WGAN е твърде чувствителен към това изрязване, за да бъде възможно най-ефективен.

Въведете наказание за градиент,

Gradient Penalty е друг, по-малко ограничителен метод за налагане на ограничението 1-Lipschitz. По дефиницията на функция 1-Lipschitz (вижте уравнение 3), функцията удовлетворява това ограничение, ако максималната норма на градиентите е 1.

Това се прилага в загубата на дискриминатора чрез добавяне на допълнителен наказателен термин за градиент в загубата. Този термин се изчислява, като се вземе средната стойност на всички градиенти на критика, когато се подават смесица от реални и фалшиви данни и се вземе средната стойност на квадрата на разликата между всички градиенти и един. Хиперпараметър от 10 е стандартният множител за този член.

Този метод включва изчисляване на всички градиенти на дискриминатора на всяка стъпка от обучението и е много скъп от изчислителна гледна точка.

Какво е термин за съгласуваност?

Допълнителна регуляризация се добавя към невронната мрежа с въвеждането на термин за съгласуваност.

Това се изчислява чрез подаване на същите реални данни на критика два пъти с отпадане от около 0,5. Ние вземаме както крайния изход, така и активациите преди крайния изход.

Отпадането по същество изключва половината тегла в критика, което прави двата изхода различни един от друг, въпреки че са еднакви входни данни.

Терминът за последователност благоприятства критика да произведе същия скаларен изход въпреки разликата в отпадането. Това се прави, за да се повиши последователността и надеждността на изхода на критиката.

Реализация 2: WGAN със спектрална нормализация

Отново, за да обучим WGAN, трябва нашата функция да бъде непрекъсната на Lipshitz. Лесно е да се види от уравнението, че в 1-D случай това означава, че стойността на K трябва да бъде по-голяма от максималната стойност на производната на функцията. (За визуално доказателство за това вижте: christigancosgrove.com).

Тук се намесва спектралната нормализация. Тъй като знаем, че WGAN трябва да бъдат непрекъснати на Липшиц, трябва да намерим начин да ограничим градиентите на дискриминатора.

В случай на многомерна функция A: Rⁿ -› Rᵐ, спектралната норма се определя като най-голямата сингулярна стойност на A, която също е квадратен корен от най-голямата собствена стойност на AᵀA (вижте тук). Чрез куп линейна алгебра, която е перфектно обобщена „тук“). Можем да видим, че това също е константата на Липшиц на линейната функция.

Сега, когато виждаме, че константата на Липшиц за тази обща линейна и диференцируема функция е нейната спектрална норма на нейния градиент върху нейната област:

Сега трябва да намерим (и да съдържаме) константата на Липшиц за състава на функциите (дискриминатора), така че да е непрекъсната на Липшиц и да работи като WGAN.

Според верижното правило за композиция от функции:

Където членовете вдясно са просто градиенти на матрица, която се умножава заедно. Следователно можем да намерим спектралната норма на състава, като просто намерим спектралната норма на произведението на градиентите.

Горното може да се раздели на това окончателно уравнение:

Ако можем да фиксираме всяка от спектралните норми на линейните функции на 1, тогава композицията също ще бъде фиксирана на 1 и WGAN ще задоволи двойствеността на Кантарович-Рубенщайн.

Сега, последното нещо, което остава да направите, е да се уверите, че спектралната норма на всяка линейна функция е ограничена. Това може просто да се направи чрез изчисляване на следното: W/σ(W), където σ(W) е единствената най-голяма стойност на W (това е спектралната норма).

Това се прави с хитър трик, наречен мощност итерация. Което след известна манипулация се свежда до простото уравнение:

Където uᵀ и v са вектори съответно в кодомейна и домейна на W. Това е страхотно, защото трябва само да изчислим векторите u и v за всяко тегло на всяка стъпка от обучението. Това е много евтино от изчислителна гледна точка и прави това внедряване много по-бързо от WGAN-GP.

Ето страхотно алгоритмично изображение на внедряване на спектрална норма:

Последни мисли

Като цяло спектралната нормализация е най-добрият метод за оптимизиране на WGAN.

Във всяка епоха двата GAN бяха много сходни по отношение на производителността с много подобно качество на изображението във всяка епоха.

Въпреки това, за 4700 итерации, прилагането на спектралната норма отне само 25 минути, докато наказанието за градиент отне 3 часа и 35 минути. Това има смисъл поради голямото количество допълнителни режийни разходи, необходими при изпълнението на градиентното наказание.

Едно нещо, което трябва да се отбележи обаче е, че дискриминаторът в случая на изрязване на градиента беше маркер за качеството на генерираните изображения, докато не беше със спектрална нормализация. Числото е хубаво, защото е точен начин да разберете дали GAN все още се подобрява.