Вызов метода super forward ()

Каков наиболее подходящий способ вызова forward() метода родительского Module? Например, если я создаю подкласс модуля nn.Linear, я могу сделать следующее

class LinearWithOtherStuff(nn.Linear):
    def forward(self, x):
        y = super(Linear, self).forward(x)
        z = do_other_stuff(y)
        return z

Однако в документах не рекомендуется вызывать метод forward(). напрямую:

Хотя рецепт прямого прохода должен быть определен в этой функции, следует вызывать экземпляр модуля позже вместо этого, поскольку первый заботится о запуске зарегистрированных хуков, а второй молча игнорирует их.

что заставляет меня думать, что super(Linear, self).forward(x) может привести к неожиданным ошибкам. Это правда, или я неправильно понимаю наследование?


person dkv    schedule 18.02.2019    source источник
comment
Обратите внимание, что это больше вопрос семантики Module.forward, чем super.   -  person chepner    schedule 18.02.2019
comment
Строка документации для Module показывает пример подкласса, который вообще не вызывает Module.forward. Похоже, что Module.__call__ заботится о добавлении forward метода в список ловушек, которые вызываются автоматически, а не о том, что вам нужно вызывать явно самостоятельно.   -  person chepner    schedule 18.02.2019
comment
Есть ли особая причина использовать наследование вместо композиции?   -  person Querenker    schedule 16.03.2019
comment
например если do_other_stuff(y) изменяет переменную экземпляра LinearWithOtherStuff   -  person dkv    schedule 16.03.2019


Ответы (1)


TL; DR;

Вы можете свободно использовать super().forward(...) даже с крючками и даже с крючками, зарегистрированными в super() экземпляре.

Объяснение

Как указано этим ответ __call__ находится здесь, поэтому зарегистрированные хуки (например, _ 4_) будет запущен.

Если вы наследуете и хотите повторно использовать forward базового класса, например это:

import torch


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        return super(Child, self).forward(tensor) + 1


module = Child()
# Increment output by 1 so we should get `4`
module.register_forward_hook(lambda module, input, output: output + 1)
print(module(torch.tensor(1))) # and it is 4 indeed
print(module.forward(torch.tensor(1))) # here it is 3 still

У вас все в порядке, если вы вызываете __call__ метод, forward не будет запускать ловушку (так что вы получите 3, как указано выше).

Вряд ли вы захотите register_hook на экземпляре super, но давайте рассмотрим такой пример:

def increment_by_one(module, input, output):
    return output + 1


class Parent(torch.nn.Module):
    def forward(self, tensor):
        return tensor + 1


class Child(Parent):
    def forward(self, tensor):
        # Increment by `1` from Parent
        super().register_forward_hook(increment_by_one)
        return super().forward(tensor) + 1


module = Child()
# Increment output by 1 so we should get `5` in total
module.register_forward_hook(increment_by_one)
print(module(torch.tensor(1)))  # and it is 5 indeed
print(module.forward(torch.tensor(1)))  # here is 3

Вы прекрасно работаете с super().forward(...), и даже хуки будут работать правильно (и это основная идея использования __call__ вместо forward).

Кстати. Вызов super().__call__(...) вызовет InifiniteRecursion ошибку.

person Szymon Maszke    schedule 22.09.2020