Надеюсь, таких программистов, как я, меньше.

Я думаю, что torch.scatter и torch.gather — два самых сложных тензорных API в пакете PyTorch. Официальные документы умного объяснения не дают, а лишь предлагают эквивалентный код:

# scatter
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
# gather
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

Однако это не совсем интуитивно понятно.

Наивный пример

torch.Tensor.scatter имеет 4 параметра

(dim, index, src, reduce=None)

Сначала игнорируйте reduce, я объясню это в конце. Начнем с самых простых. У нас есть тензор формы (2, 4), заполненный единицами:

Затем мы вызываем его метод scatter и передаем dim, index и src следующим образом:

а затем получить вывод

Чтобы понять, как это работает, давайте увеличим index и посмотрим на результаты.

Посмотрите внимательно здесь, мы можем заметить, что

  1. старое значение вывода, т. е. 1, получено из входного тензора
    — на выходе по-прежнему семь единиц
  2. новоезначение вывода, т. е. 5, получено из src
    — вы видите 5 в выводе
    — 5 тесно связано с индекс
  3. 5 сдвигается в зависимости от значения в индексе
    - 0~3
    меняет положение 5
  4. значение внутри index будет использоваться для разделения входного тензора и его обновления
    - 4 вызовет ошибку выхода за границы
    - dim Параметр означает, по какому измерению следует разрезать

Краткое заключение

  1. индекс позволит PyTorch извлечь 5 из src в качестве возможного значения и использовать его для обновления входного тензора.
  2. Значение внутри индекса указывает PyTorch явную позицию для обновления.

Большой! Мы понимаем основы, но все еще недостаточно ясно.

Как индекс управляет извлечением из src?

Извлечение из источника

Теперь оставьте все как есть, кроме структуры индекса.

Хм… странно, да? Не волнуйтесь, давайте увеличим второе значение внутри нового индекса:

Итак, новый индекс дает такое же поведение, как и предыдущий, но как первый вывод дает нам только одну 6? Куда делся 5?

Вот что происходит под капотом:

индексструктурафактически расширяется до той же структуры, что и src.

Если внутри index есть значение, извлеките значение из src в той же позиции. Если значения нет, ничего не делайте. Здесь у нас есть 0 и 3, поэтому извлеките 5 и 6, так как 5 и 6 имеют ту же позицию, что и расширенный индекс.

Это означает, что структура индекса должна быть подструктурой src. В противном случае вы получите ошибку:

Теперь вернемся к исходному примеру, мы извлекли 5 и 6. Куда поставить 5? 0; Куда поставить 6? 3. 0 и 3 — это значения внутри индекса.

тензор имеет два измерения, какое из них мы должны разрезать вдоль? В нашем случае dim=1.

Вдоль dim=1 означает, что индекс будет указывать столбец в исходном тензоре. Позиция строки останется такой же, как 5 и 6 лежат в src (в нашем случае, строка 0).

Итак, вы можете себе представить, почему мы получаем только одну 6, когда индекс равен [[0,0]]. Ячейка обновлялась дважды, с 1 на 5 и с 5 на 6.

Что, если dim=0? Давай попробуем.

Вдоль dim=0 означает, что индекс будет указывать строку в исходном тензоре. Позиция столбца останется такой же, как 5 и 6 лежат в src (в нашем случае столбец 0 и столбец 1).

Краткое заключение

  1. index должен иметь подструктуру структуры src
  2. dimуказывает, где разрезать входной тензор

Более интуитивно вы можете рассматривать index как «нескользящее» окно src. Посмотрите на следующую диаграмму, стрелка извлечения объясняет, как это работает.

Обратное мышление

Теперь подумайте наоборот, у нас есть только вход и выход, как нам установить параметры, чтобы получить желаемый результат?

Вот наш входной тензор

Мы хотим получить следующий результат, вызвав метод разброса:

Если dim=0, мы будем нарезать вдоль строки

Две шестерки в выходных данных поступают из src, они извлекаются на основе непустого значения внутри index. Следовательно, src должен содержать как минимум два 6

а индекс должен быть не менее

Мы не можем оставить ? внутри index пусто, там должно быть значение. Сейчас у нас есть только два варианта: 0 и 1, потому что входной тензор имеет только две строки. Если мы поставим 1, ? в src изменит значение строки во входных данных.

Если мы поставим 0, ? в src изменит значение нулевой строки во входных данных.

В обоих случаях ? в src должно быть 1, иначе мы не получим желаемого вывода.

Что, если dim=1с тем же входом и тем же выходом?

Во-первых, у нас разные направления работы.

Мы разрезаем столбец, поэтому значение внутри индекса может варьироваться от 0 до 3. Мы хотим изменить входной столбец 1 и столбец 3 на 6, поэтому индекс и источник должен быть

Смотреть! Как это элегантно! По сравнению со случаем dim=0 это гораздо лучший выбор. Нам не нужно заполнять 1 в src. Обычно плохо выбирать метод, который позволяет нашему параметру зависеть от выходного значения, dim=1 не даст использовать это условие.

Напомним, что связь между index и src — это извлечение. PyTorch будет извлекать значения из src на основе структуры index. Это означает, что размер src не имеет значения, пока PyTorch может извлекать из него ценность. То есть следующие варианты также работают

Мы получим ошибки, только если индекс слишком велик.

Ознакомьтесь с официальным документом

Вот код объяснения в официальном документе.

# scatter
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

Давайте разберем его, раскрасив

Это трехмерный случай, давайте изменим его на двухмерный.

  1. Сине-зеленый цвет указывает на взаимосвязь между index и src.
    - iи jдолжны быть допустимы как для index, так и для src, поэтому index должен быть подструктурой структуры src
    — насколько велик src не имеет значения, пока PyTorch может извлечь из него ценность
  2. Бело-голубой цвет означает, что значение внутри index не может превышать размер измерения, указанного dim.

Уменьшать

Наконец, последняя часть. Параметр reduce имеет три варианта: «Нет», «добавить» или «умножить». Если «добавить», назначение значения станет «добавить и заменить»; если «умножить», присвоение значения станет «умножить и заменить». Довольно просто понять, не так ли?

Заключение

Я понятия не имею, почему я хочу написать этот учебник, так устал.