Дано има по-малко програмисти, които се борят като мен.

Мисля, че torch.scatter и torch.gather са двата най-трудни API за тензор в пакета PyTorch. „Официалните документи“ не ни дават умно обяснение, те предлагат само еквивалентен код:

# scatter
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
# gather
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

Това обаче не е съвсем интуитивно.

Наивният пример

torch.Tensor.scatter има 4 параметъра

(dim, index, src, reduce=None)

Първо игнорирайте намаляване, ще го обясня накрая. Да започнем с най-простите. Имаме тензор с форма (2, 4), пълен с единици:

След това извикваме неговия метод scatter и предаваме dim, index и src както следва

и след това вземете изхода

За да разберете как работи, нека увеличим индекса и да видим резултатите

Погледнете внимателно тук, можем да го забележим

  1. старата стойност на изхода, т.е. 1, идва от входния тензор
    - все още има седем в изхода
  2. новастойност на изхода, т.е. 5, идва от src
    - виждате 5 в изхода
    - 5 е силно свързано с индекс
  3. 5 се измества въз основа на стойността в индекса
    - 0~3
    промяна на позицията на 5
  4. стойност вътре в индекс ще се използва за разделяне на входния тензор и актуализирането му
    - 4 ще доведе до грешка извън границите
    - dim параметърът означава по кое измерение да се нарязва

Кратко заключение

  1. Индексътще позволи на PyTorch да извлече 5 от srcкато кандидат стойност и да го използва за актуализиране на входния тензор.
  2. Стойността вътре в индексаказва на PyTorch изричната позиция за актуализиране.

Страхотен! Разбираме основите, но все още не сме достатъчно ясни.

Как индекс контролира извличането от src?

Извличане от src

Сега оставете всичко същото, с изключение на структурата на индекса.

Хм… странно, нали? Не се притеснявайте, нека увеличим втората стойност в новия индекс:

Добре, така че новият индекс дава подобно поведение като предишния, но как първият изход ни дава само едно 6? Къде отива 5?

Ето какво се случва под капака:

Индекснатаструктуравсъщност е разширена до същата структура като src.

Ако има стойност вътре в index, извлечете стойността от src на същата позиция. Ако няма стойност, не правете нищо. Тук имаме 0 и 3, така че извлечете 5 и 6, тъй като 5 и 6 имат същата позиция като разширения индекс.

Това означава, че структурата на индекса трябва да бъде подструктура на src. В противен случай ще получите грешка:

Сега се върнете към оригиналния пример, извлякохме 5 и 6. Къде да поставим 5? 0; Къде да сложа 6? 3. 0 и 3 са стойностите в индекса.

Тензорът има две измерения, кое трябва да разделим по? В нашия случай, dim=1.

За dim=1означава, че индекс ще посочи колоната в оригиналния тензор. Позицията на реда ще остане същата като там, където се намират 5 и 6 в src(в нашия случай ред 0).

И така, можете да си представите защо получаваме само едно 6, когато индекс е [[0,0]]. Клетката беше актуализирана два пъти, от 1 до 5 и от 5 до 6.

Ами ако dim=0? Да пробваме.

За dim=0 означава, че индексътще посочи реда в оригиналния тензор. Позицията на колоната ще остане същата като там, където се намират 5 и 6 в src(в нашия случай колона 0 и колона 1).

Кратко заключение

  1. indexтрябва да има подструктура на структурата на src
  2. dimпоказва къде да се нарязва по входния тензор

По-интуитивно можете да разглеждате индекса като „неплъзгащ се“ прозорец на src. Погледнете следната диаграма, стрелката за извличане обяснява как работи.

Обратно мислене

Сега помислете обратното, имаме само вход и изход, как да зададем параметрите, за да получим желания резултат?

Ето нашия входен тензор

Искаме следния изход чрез извикване на scatter метод,

Ако dim=0, ще изрежем по посока на реда

Двете 6s в изхода идват от src, те се извличат въз основа на непразната стойност в индекс. Следователно src трябва да съдържа поне две 6

и индексът трябва да бъде поне

Не можем да оставим? вътре индекс е празен, те трябва да са стойност там. В момента имаме само две опции: 0 и 1, тъй като входният тензорима само два реда. Ако поставим 1, ? в src ще промени стойността на ред едно във входа.

Ако поставим 0, ? в srcще промени стойността на нулевия ред във входа.

И в двата случая ? в src трябва да бъде 1, в противен случай няма да получим желания резултат.

Ами ако dim=1със същия вход, същия изход?

Първо, имаме различна посока на работа.

Разрязваме колоната, така че стойността вътре в индекса може да варира от 0 до 3. Искаме да променим колона 1 на входа и колона 3 на 6, така че индексъти src трябва да бъде

Виж! Колко е елегантно! В сравнение със случая dim=0, това е много по-добър избор. Не е необходимо да попълваме 1 в src. Обикновено е лошо да избираме метод, който позволява нашия параметър да зависи от изходната стойност, dim=1 няма да даде възможност за използване на това условие.

Спомнете си, че връзката между index и src е извличане. PyTorch ще извлече стойности от src въз основа на структурата на индекс. Това означава, че колко голям src няма значение, стига PyTorch да може да извлича стойност от него. Тоест, следните възможности за избор също работят

Ще получаваме грешки само когато индексъте твърде голям.

Разберете официалния документ

Ето кода на обяснението в официалния документ.

# scatter
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

Нека го разбием чрез оцветяване

Това е 3-измерен случай, нека го променим на 2-измерен.

  1. Синьо-зеленото показва връзката между индекс и src.
    - iи jтрябва да са приложими както за индекс, така и за srcформа, така че индексъттрябва да бъде подструктура на структурата на src
    - колко голям от src няма значение, стига PyTorch да може да извлече стойност от него
  2. Синьо-бялото означава, че стойността вътре в indexне може да надвишава размера на измерението, указано от dim.

Намалете

И накрая, последната част. Параметърът reduce има три възможности за избор: Няма, „добавяне“ или „умножаване“. Ако „добави“, присвоената стойност ще стане „добавяне и замяна“; ако „умножи“, присвоената стойност ще стане „умножи и замени“. Доста лесно за разбиране, нали?

Заключение

Нямам представа защо искам да напиша този урок, толкова съм уморен.