Вы когда-нибудь пробовали использовать функцию сборки pytorch? Я сделал это, и это было очень сложно. Сама функция довольно полезна, но понять, как ее использовать, может быть сложно.

Итак, какова цель функции сбора? Документы говорят:

torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
Gathers values along an axis specified by dim.

Итак, он собирает значения по оси. Но чем это отличается от обычной индексации? При использовании оператора [] вы выбираете один и тот же индекс во всех местах. Рассмотрим тензор 4x6 (4 - размер партии, 6 - характеристики). Когда вы делаете x[_,:] или x[:, _], вы выбираете один и тот же индекс в каждом пакете / функции

Но представьте следующую ситуацию: вам нравится выбирать 3-ю функцию из 0-го примера, 7-ю функцию из 1-го примера, 4-ю из 3-го и 1-ю из 4-го.

Вы можете подумать о:

indices = torch.LongTensor([3,7,4,1])
x[:, indices]

Но вы получите:

tensor([[ 3,  7,  4,  1],
        [13, 17, 14, 11],
        [23, 27, 24, 21],
        [33, 37, 34, 31]])

Хорошо, нам нужна функция сбора.

Gather требует трех параметров:

  • input - входной тензор
  • dim - измерение вдоль для сбора значений
  • index - тензор с индексами значений для сбора

Важное замечание: размерность input и index должна быть одинаковой, за исключением измерения dim. Например, если введено 4x10x15 и dim = 0, то индекс должен быть Nx10x15.

2D пример

Вернемся к нашему примеру. Мы знаем входной тензор. Измерение для сбора - 1 (потому что мы хотим варьировать индексы в измерении 1 - функции, а не примеры). Мы можем сделать вывод, что индекс должен быть 4x1. У нас есть четыре значения, которые нужно заполнить, и это будет просто:

indices = torch.LongTensor([3,7,4,1])
indices = indices.unsqueeze(-1)
print(indices.shape)
print(indices)
> torch.Size([4, 1])
> tensor([[3],
        [7],
        [4],
        [1]])

Сжатие необходимо для расширения тензора с дополнительным измерением в конце (поэтому есть аргумент -1), преобразовывающего индексы из формы 4 в 4x1

Подключим и посмотрим, работает ли:

tensor([[ 3],
        [17],
        [24],
        [31]])

Да, работает!

Теперь упражнение для вас: с помощью сборки попробуйте извлечь следующие значения.

Пример 3D

В трех измерениях все становится сложнее.

Представьте, что у нас есть следующий сценарий: сеть RNN с последовательностями, дополненными до максимальной длины. Мы хотели бы собрать последний элемент в каждой последовательности со всеми функциями из скрытого состояния rnn.

Входные данные соответствуют тензору BATCH_SIZE x MAX_SEQ_LEN x HIDDEN_STATE. В нашем примере исправим это: размер партии = 8, max_seq_len = 9, hidden_size = 6.

batch_size = 8
max_seq_len = 9
hidden_size = 6
x = torch.empty(batch_size, max_seq_len, hidden_size)
for i in range(batch_size):
  for j in range(max_seq_len):
    for k in range(hidden_size):
      x[i,j,k] = i + j*10 + k*100

Есть причина, по которой я создаю пример ввода таким образом - будет легко отследить, какие элементы мы на самом деле получаем. Например, значение «123» означает «1-й пакет, 2-й элемент последовательности, 3-е скрытое состояние». Если мы это сделаем:

x[:,4,:]
>tensor([[ 40., 140., 240., 340., 440., 540.],
        [ 41., 141., 241., 341., 441., 541.],
        [ 42., 142., 242., 342., 442., 542.],
        [ 43., 143., 243., 343., 443., 543.],
        [ 44., 144., 244., 344., 444., 544.],
        [ 45., 145., 245., 345., 445., 545.],
        [ 46., 146., 246., 346., 446., 546.],
        [ 47., 147., 247., 347., 447., 547.]])

довольно легко отследить, что мы получили элементы из всех партий (все последние цифры находятся в диапазоне [0–7]), 4-го элемента последовательности (все числа в формате x4x) и всех скрытых состояний (все сотни цифр находятся внутри [0– 5]).

Забегая вперед, мы знаем индекс каждого последнего элемента последовательности (это может быть индекс последнего токена в предложении для задач НЛП)

lens = torch.LongTensor([5,6,1,8,3,7,3,4])

Теперь просто хочу извлечь значения скрытого состояния линзы из каждого примера.

Входной файл имеет форму размер_пакета x max_seq_len x hidden_state (8x9x6). Мы хотим собирать по измерению seq_len (1), поэтому форма индекса должна быть:

8x1x6

Итак, нам нужно заполнить 42 значения (8 * 6), но у нас есть 8 значений (по одному для каждого примера в пакете - индекс последнего элемента). Главное - понять, что у нас есть 6 скрытых состояний, и мы хотим собрать их все - ясно, что мы ожидаем 42 значения! (6 скрытых состояний для 8 примеров). Решение довольно простое, нам просто нужно повторить lens 6 раз:

lens = torch.LongTensor([5,6,1,8,3,7,3,4])
# add one trailing dimension
lens = lens.unsqueeze(-1)
print(lens.shape)
> torch.Size([8, 1])
# repeat 6 times
indices = lens.repeat(1,6)
print(indices.shape)
> torch.Size([8, 6])
print(indices)
> tensor([[5, 5, 5, 5, 5, 5],
        [6, 6, 6, 6, 6, 6],
        [1, 1, 1, 1, 1, 1],
        [8, 8, 8, 8, 8, 8],
        [3, 3, 3, 3, 3, 3],
        [7, 7, 7, 7, 7, 7],
        [3, 3, 3, 3, 3, 3],
        [4, 4, 4, 4, 4, 4]])

Продвигаясь вперед, добавьте «пустой» размер посередине (чтобы получилось 8x1x6).

indices = indices.unsqueeze(1)
print(indices.shape)
> torch.Size([8, 1, 6])

И примените это

results = torch.gather(x, 1, indices)
print(results.shape)
> torch.Size([8, 1, 6])

Форма правильная. Ценности тоже есть?

tensor([[[ 50., 150., 250., 350., 450., 550.]],
        [[ 61., 161., 261., 361., 461., 561.]],
        [[ 12., 112., 212., 312., 412., 512.]],
        [[ 83., 183., 283., 383., 483., 583.]],
        [[ 34., 134., 234., 334., 434., 534.]],
        [[ 75., 175., 275., 375., 475., 575.]],
        [[ 36., 136., 236., 336., 436., 536.]],
        [[ 47., 147., 247., 347., 447., 547.]]])

Мы видим, что он содержит все примеры из партии, в каждом примере есть все функции, и мы видим, что элементы последовательности соответственно 5,6,1,8,3,7,3,4

Что случилось?

Чтобы понять, что это означает, давайте разберемся с одним примером из партии.

indices[0,:]
> tensor([5, 5, 5, 5, 5, 5])

Он содержит 6 элементов, соответствующее количеству скрытых состояний - это означает, что для каждого из скрытых состояний в 0-м примере выберите 5-й элемент по оси тусклого света. То же самое происходит со 2-м пакетом: тензор, содержащий [1,1,1,1,1,1], значение для каждого из 6 скрытых состояний, которые мы хотим получить из 1-й позиции предложения.

Чтобы было еще понятнее, я создал простую визуализацию. Настройка такая же, за исключением того, что теперь линза = [2,2,2,4,4,4,6,7]. Вы можете ясно видеть, как значения из индексов (записанных на нижней плоскости) соответствуют собранным значениям из входного тензора.

Заключение

Я надеюсь, что теперь вы понимаете, как использовать функцию сбора. Конечно, если у вас есть данные 4D, применима та же логика.