Вы когда-нибудь пробовали использовать функцию сборки pytorch? Я сделал это, и это было очень сложно. Сама функция довольно полезна, но понять, как ее использовать, может быть сложно.
Итак, какова цель функции сбора? Документы говорят:
torch.gather
(input, dim, index, out=None, sparse_grad=False) → Tensor
Gathers values along an axis specified by dim.
Итак, он собирает значения по оси. Но чем это отличается от обычной индексации? При использовании оператора []
вы выбираете один и тот же индекс во всех местах. Рассмотрим тензор 4x6 (4 - размер партии, 6 - характеристики). Когда вы делаете x[_,:]
или x[:, _]
, вы выбираете один и тот же индекс в каждом пакете / функции
Но представьте следующую ситуацию: вам нравится выбирать 3-ю функцию из 0-го примера, 7-ю функцию из 1-го примера, 4-ю из 3-го и 1-ю из 4-го.
Вы можете подумать о:
indices = torch.LongTensor([3,7,4,1]) x[:, indices]
Но вы получите:
tensor([[ 3, 7, 4, 1], [13, 17, 14, 11], [23, 27, 24, 21], [33, 37, 34, 31]])
Хорошо, нам нужна функция сбора.
Gather требует трех параметров:
- input - входной тензор
- dim - измерение вдоль для сбора значений
- index - тензор с индексами значений для сбора
Важное замечание: размерность input и index должна быть одинаковой, за исключением измерения dim. Например, если введено 4x10x15 и dim = 0, то индекс должен быть Nx10x15.
2D пример
Вернемся к нашему примеру. Мы знаем входной тензор. Измерение для сбора - 1 (потому что мы хотим варьировать индексы в измерении 1 - функции, а не примеры). Мы можем сделать вывод, что индекс должен быть 4x1. У нас есть четыре значения, которые нужно заполнить, и это будет просто:
indices = torch.LongTensor([3,7,4,1]) indices = indices.unsqueeze(-1) print(indices.shape) print(indices) > torch.Size([4, 1]) > tensor([[3], [7], [4], [1]])
Сжатие необходимо для расширения тензора с дополнительным измерением в конце (поэтому есть аргумент -1), преобразовывающего индексы из формы 4 в 4x1
Подключим и посмотрим, работает ли:
tensor([[ 3], [17], [24], [31]])
Да, работает!
Теперь упражнение для вас: с помощью сборки попробуйте извлечь следующие значения.
Пример 3D
В трех измерениях все становится сложнее.
Представьте, что у нас есть следующий сценарий: сеть RNN с последовательностями, дополненными до максимальной длины. Мы хотели бы собрать последний элемент в каждой последовательности со всеми функциями из скрытого состояния rnn.
Входные данные соответствуют тензору BATCH_SIZE x MAX_SEQ_LEN x HIDDEN_STATE. В нашем примере исправим это: размер партии = 8, max_seq_len = 9, hidden_size = 6.
batch_size = 8 max_seq_len = 9 hidden_size = 6 x = torch.empty(batch_size, max_seq_len, hidden_size) for i in range(batch_size): for j in range(max_seq_len): for k in range(hidden_size): x[i,j,k] = i + j*10 + k*100
Есть причина, по которой я создаю пример ввода таким образом - будет легко отследить, какие элементы мы на самом деле получаем. Например, значение «123» означает «1-й пакет, 2-й элемент последовательности, 3-е скрытое состояние». Если мы это сделаем:
x[:,4,:] >tensor([[ 40., 140., 240., 340., 440., 540.], [ 41., 141., 241., 341., 441., 541.], [ 42., 142., 242., 342., 442., 542.], [ 43., 143., 243., 343., 443., 543.], [ 44., 144., 244., 344., 444., 544.], [ 45., 145., 245., 345., 445., 545.], [ 46., 146., 246., 346., 446., 546.], [ 47., 147., 247., 347., 447., 547.]])
довольно легко отследить, что мы получили элементы из всех партий (все последние цифры находятся в диапазоне [0–7]), 4-го элемента последовательности (все числа в формате x4x) и всех скрытых состояний (все сотни цифр находятся внутри [0– 5]).
Забегая вперед, мы знаем индекс каждого последнего элемента последовательности (это может быть индекс последнего токена в предложении для задач НЛП)
lens = torch.LongTensor([5,6,1,8,3,7,3,4])
Теперь просто хочу извлечь значения скрытого состояния линзы из каждого примера.
Входной файл имеет форму размер_пакета x max_seq_len x hidden_state (8x9x6). Мы хотим собирать по измерению seq_len (1), поэтому форма индекса должна быть:
8x1x6
Итак, нам нужно заполнить 42 значения (8 * 6), но у нас есть 8 значений (по одному для каждого примера в пакете - индекс последнего элемента). Главное - понять, что у нас есть 6 скрытых состояний, и мы хотим собрать их все - ясно, что мы ожидаем 42 значения! (6 скрытых состояний для 8 примеров). Решение довольно простое, нам просто нужно повторить lens 6 раз:
lens = torch.LongTensor([5,6,1,8,3,7,3,4]) # add one trailing dimension lens = lens.unsqueeze(-1) print(lens.shape) > torch.Size([8, 1]) # repeat 6 times indices = lens.repeat(1,6) print(indices.shape) > torch.Size([8, 6]) print(indices) > tensor([[5, 5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], [1, 1, 1, 1, 1, 1], [8, 8, 8, 8, 8, 8], [3, 3, 3, 3, 3, 3], [7, 7, 7, 7, 7, 7], [3, 3, 3, 3, 3, 3], [4, 4, 4, 4, 4, 4]])
Продвигаясь вперед, добавьте «пустой» размер посередине (чтобы получилось 8x1x6).
indices = indices.unsqueeze(1) print(indices.shape) > torch.Size([8, 1, 6])
И примените это
results = torch.gather(x, 1, indices) print(results.shape) > torch.Size([8, 1, 6])
Форма правильная. Ценности тоже есть?
tensor([[[ 50., 150., 250., 350., 450., 550.]], [[ 61., 161., 261., 361., 461., 561.]], [[ 12., 112., 212., 312., 412., 512.]], [[ 83., 183., 283., 383., 483., 583.]], [[ 34., 134., 234., 334., 434., 534.]], [[ 75., 175., 275., 375., 475., 575.]], [[ 36., 136., 236., 336., 436., 536.]], [[ 47., 147., 247., 347., 447., 547.]]])
Мы видим, что он содержит все примеры из партии, в каждом примере есть все функции, и мы видим, что элементы последовательности соответственно 5,6,1,8,3,7,3,4
Что случилось?
Чтобы понять, что это означает, давайте разберемся с одним примером из партии.
indices[0,:] > tensor([5, 5, 5, 5, 5, 5])
Он содержит 6 элементов, соответствующее количеству скрытых состояний - это означает, что для каждого из скрытых состояний в 0-м примере выберите 5-й элемент по оси тусклого света. То же самое происходит со 2-м пакетом: тензор, содержащий [1,1,1,1,1,1], значение для каждого из 6 скрытых состояний, которые мы хотим получить из 1-й позиции предложения.
Чтобы было еще понятнее, я создал простую визуализацию. Настройка такая же, за исключением того, что теперь линза = [2,2,2,4,4,4,6,7]. Вы можете ясно видеть, как значения из индексов (записанных на нижней плоскости) соответствуют собранным значениям из входного тензора.
Заключение
Я надеюсь, что теперь вы понимаете, как использовать функцию сбора. Конечно, если у вас есть данные 4D, применима та же логика.