Опитвали ли сте някога да използвате функцията за събиране на pytorch? Направих го и беше много трудно. Самата функция е доста полезна, но разбирането как да я използвате може да бъде болезнено.

И така, каква е целта на функцията за събиране? Docs казва:

torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
Gathers values along an axis specified by dim.

Така че събира стойности по оста. Но как се различава от обикновеното индексиране? Когато използвате оператор [], избирате един и същ индекс на всяко място. Помислете за тензор 4x6 (4 е за размера на партидата, 6 е за характеристиките). Когато правите x[_,:] или x[:, _], вие избирате един и същ индекс във всяка партида/функция

Но си представете следната ситуация: искате да изберете 3-та функция от 0-ти пример, 7-ма характеристика от 1-ви пример, 4-та от 3-ти и 1-ва от 4-ти.

Може да се сетите за:

indices = torch.LongTensor([3,7,4,1])
x[:, indices]

Но ще получите:

tensor([[ 3,  7,  4,  1],
        [13, 17, 14, 11],
        [23, 27, 24, 21],
        [33, 37, 34, 31]])

Добре, имаме нужда от функция за събиране.

Gather изисква три параметъра:

  • вход — входен тензор
  • dim — измерение за събиране на стойности
  • индекс — тензор с индекси на стойностите за събиране

Важно съображение е, че размерността на входаи индексатрябва да бъде същата, освен в dimизмерението. Например, ако входът е 4x10x15 и dim=0, тогава индексът трябва да бъде Nx10x15.

2D пример

Връщайки се към нашия пример. Знаем входния тензор. Измерението за събиране е 1 (защото искаме да варираме индексите в измерение 1 — функции, не примери). Можем да заключим, че индексът трябва да бъде 4x1. Имаме четири стойности за попълване и това ще бъде просто:

indices = torch.LongTensor([3,7,4,1])
indices = indices.unsqueeze(-1)
print(indices.shape)
print(indices)
> torch.Size([4, 1])
> tensor([[3],
        [7],
        [4],
        [1]])

Unsqueeze е необходимо за разширяване на тензора с допълнително измерение в края (затова има аргумент -1), преобразуващо индексите от форма 4 в 4x1

Нека го включим и да видим дали работи:

tensor([[ 3],
        [17],
        [24],
        [31]])

Да работи!

Сега упражнете за вас: използвайки gather, опитайте се да извлечете следните стойности.

3D пример

В три измерения нещата стават по-трудни.

Представете си, че имаме следния сценарий: RNN мрежа с последователности, подплатени до максимална дължина. Бихме искали да съберем последния елемент във всяка последователност, с всички характеристики от скритото състояние на rnn.

Входните данни следват тензора BATCH_SIZE x MAX_SEQ_LEN x HIDDEN_STATE. За нашия пример нека го поправим на: размер на партидата = 8, max_seq_len = 9, hidden_size = 6

batch_size = 8
max_seq_len = 9
hidden_size = 6
x = torch.empty(batch_size, max_seq_len, hidden_size)
for i in range(batch_size):
  for j in range(max_seq_len):
    for k in range(hidden_size):
      x[i,j,k] = i + j*10 + k*100

Има причина да конструирам примерен вход по такъв начин — ще бъде лесно да проследим какви елементи всъщност получаваме. Например стойност „123“ означава „1-ва партида, 2-ри елемент от последователност, 3-то скрито състояние“. Ако го направим:

x[:,4,:]
>tensor([[ 40., 140., 240., 340., 440., 540.],
        [ 41., 141., 241., 341., 441., 541.],
        [ 42., 142., 242., 342., 442., 542.],
        [ 43., 143., 243., 343., 443., 543.],
        [ 44., 144., 244., 344., 444., 544.],
        [ 45., 145., 245., 345., 445., 545.],
        [ 46., 146., 246., 346., 446., 546.],
        [ 47., 147., 247., 347., 447., 547.]])

доста лесно е да се проследи, че имаме елементи от всички партиди (всички последни цифри са в диапазона [0–7]), елемент от 4-та последователност (всички числа са във формат x4x) и всички скрити състояния (всички стотици цифри са вътре в [0– 5]).

Занапред знаем индекса на всеки последен елемент от последователността (това може да бъде индексът на последния токен в изречението за NLP задачи)

lens = torch.LongTensor([5,6,1,8,3,7,3,4])

Сега просто искаме да извлечем стойностите на скритото състояние на лещата от всеки пример.

Входът е във формата batch_size x max_seq_len x hidden_state (8x9x6). Искаме да събираме по дължината на seq_len измерение (1), следователно формата на индекса трябва да бъде:

8x1x6

И така, трябва да попълним 42 стойности (8*6), но имаме 8 стойности (по една за всеки пример в партида — индекс на последния елемент). Ключът е да разберем, че имаме 6 скрити състояния и искаме да ги съберем всички — ясно е, че в очакваме 42 стойности! (6 скрити състояния за 8 примера). Решението е доста просто, просто трябва да повторим lens6 пъти:

lens = torch.LongTensor([5,6,1,8,3,7,3,4])
# add one trailing dimension
lens = lens.unsqueeze(-1)
print(lens.shape)
> torch.Size([8, 1])
# repeat 6 times
indices = lens.repeat(1,6)
print(indices.shape)
> torch.Size([8, 6])
print(indices)
> tensor([[5, 5, 5, 5, 5, 5],
        [6, 6, 6, 6, 6, 6],
        [1, 1, 1, 1, 1, 1],
        [8, 8, 8, 8, 8, 8],
        [3, 3, 3, 3, 3, 3],
        [7, 7, 7, 7, 7, 7],
        [3, 3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4, 4]])

Продължавайки напред, добавете „празно“ измерение в средата (за да стане 8x1x6)

indices = indices.unsqueeze(1)
print(indices.shape)
> torch.Size([8, 1, 6])

И го прилагайте

results = torch.gather(x, 1, indices)
print(results.shape)
> torch.Size([8, 1, 6])

Формата е правилна. Ценностите също ли са?

tensor([[[ 50., 150., 250., 350., 450., 550.]],
        [[ 61., 161., 261., 361., 461., 561.]],
        [[ 12., 112., 212., 312., 412., 512.]],
        [[ 83., 183., 283., 383., 483., 583.]],
        [[ 34., 134., 234., 334., 434., 534.]],
        [[ 75., 175., 275., 375., 475., 575.]],
        [[ 36., 136., 236., 336., 436., 536.]],
        [[ 47., 147., 247., 347., 447., 547.]]])

Виждаме, че съдържа всички примери от партида, във всеки пример има всички функции и виждаме, че елементите на последователността са съответно 5,6,1,8,3,7,3,4

Какво се случи?

За да разберем какво означава това, нека го разделим само на един пример от партида

indices[0,:]
> tensor([5, 5, 5, 5, 5, 5])

Той съдържа 6 елемента, съответстващи на броя скрити състояния — което означава, че за всяко от скритите състояния в 0-ия пример изберете 5-ти елемент по оста на дим. Същото се случва за 2-ра партида: тензор, съдържащ [1,1,1,1,1,1], което означава за всяко от 6-те скрити състояния, които бихме искали да вземем стойност от 1-ва позиция на изречението.

За да стане още по-ясно, създадох проста визуализация. Настройката е същата, но сега обектив = [2,2,2,4,4,4,6,7]. Можете ясно да видите как стойностите от индексите (записани в долната равнина) съответстват на събраните стойности от входния тензор.

Заключение

Надявам се, че сега разбирате как да използвате функцията за събиране. Разбира се, ако имате 4D данни, важи същата логика.