I’d like to update BackPACK from `register_backward_hook` to `register_full_backward_hook`. However, this also changes the behaviour of the forward hooks registered with `register_forward_hook`.
Expected/old behaviour: in the forward hook, the argument `input[0]` of the succeeding module is identical to the argument `output` of the preceding module.
New behaviour: these objects are now different. This can be seen from the ids in the minimum example.
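The old invariant can be checked directly with `is` instead of comparing ids. A minimal sketch (plain forward hooks, no backward hooks registered at all):

```python
import torch
from torch.nn import Linear, Sequential

captured = []  # (input[0], output) pairs, in execution order

def fwd_hook(module, inp, out):
    captured.append((inp[0], out))

model = Sequential(Linear(3, 3), Linear(3, 3))
for child in model:
    child.register_forward_hook(fwd_hook)

model(torch.rand(2, 3))

# old behaviour: the output object of the first Linear is exactly
# the input object of the second Linear
assert captured[0][1] is captured[1][0]
```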
What I am looking for: Is this the intended behaviour?
If no, will it be changed back to the old behaviour?
If yes, how can I achieve the following: based on the minimum example, I calculate a quantity `a_loss` in the backward hook on `_loss`. How can I access `a_loss` in the backward hook on `_linear2` to calculate `a_linear2`, and so on?
If it is not clear what I mean, I can follow up with a separate example.
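To illustrate the kind of flow I mean: one conceivable workaround (a hypothetical sketch, not BackPACK's actual mechanism) is to hand quantities from one backward hook to the next through shared state instead of matching tensors by id, relying on full backward hooks firing in reverse execution order (loss first, then the last Linear, then the first):

```python
import torch
from torch.nn import CrossEntropyLoss, Linear

results = []  # (name, quantity) pairs, filled in reverse execution order

def make_hook(name):
    def hook(module, g_inp, g_out):
        downstream = results[-1][1] if results else None  # e.g. a_loss
        # placeholder "quantity": a real implementation would derive it
        # from g_inp/g_out and `downstream`
        results.append((name, g_out[0].sum().item()))
    return hook

_linear1, _linear2 = Linear(3, 3), Linear(3, 3)
_loss = CrossEntropyLoss()
for name, mod in [("linear1", _linear1), ("linear2", _linear2), ("loss", _loss)]:
    mod.register_full_backward_hook(make_hook(name))

x = torch.rand(2, 3, requires_grad=True)
target = torch.randint(3, (2,))
_loss(_linear2(_linear1(x)), target).backward()

# hooks fired loss -> linear2 -> linear1, so each hook could read the
# quantity stored by its downstream neighbour
assert [n for n, _ in results] == ["loss", "linear2", "linear1"]
```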
Minimum example
```python
from typing import List, Tuple

import torch
from torch import Tensor
from torch.nn import CrossEntropyLoss, Linear, Module, Sequential


def _forward_hook(module: Module, input: Tuple[Tensor], output: Tensor) -> None:
    print(
        f"{type(module)}: id of input, input[0], and output: "
        f"{id(input)}, {id(input[0])}, {id(output)}"
    )


def _backward_hook(module: Module, g_inp: Tuple[Tensor], g_out: Tuple[Tensor]) -> None:
    pass


def _conduct_experiment(installed_hook: str) -> None:
    batch_size: int = 2
    num_class: int = 3
    _module: Sequential = Sequential(
        Linear(in_features=num_class, out_features=num_class),
        Linear(in_features=num_class, out_features=num_class),
    )
    _input: Tensor = torch.rand(batch_size, num_class, requires_grad=True)
    _target: Tensor = torch.empty(batch_size, dtype=torch.long).random_(num_class)
    _loss_module: CrossEntropyLoss = CrossEntropyLoss()

    # modules() yields the Sequential container itself as well as its children
    list_modules: List[Module] = list(_module.modules()) + [_loss_module]

    print(f"Register hooks with {installed_hook}.")
    for module_child in list_modules:
        module_child.register_forward_hook(_forward_hook)
        getattr(module_child, installed_hook)(_backward_hook)

    print("Forward pass.")
    _loss_module(_module(_input), _target)


_conduct_experiment("register_backward_hook")
print("\n\n")
_conduct_experiment("register_full_backward_hook")
```
Output
```text
Register hooks with register_backward_hook.
Forward pass.
<class 'torch.nn.modules.linear.Linear'>: id of input, input[0], and output: 140708990519632, 140707567033312, 140708989826944
<class 'torch.nn.modules.linear.Linear'>: id of input, input[0], and output: 140708990519632, 140708989826944, 140708989843568
<class 'torch.nn.modules.container.Sequential'>: id of input, input[0], and output: 140708989964432, 140707567033312, 140708989843568
<class 'torch.nn.modules.loss.CrossEntropyLoss'>: id of input, input[0], and output: 140708990430336, 140708989843568, 140707010881248
Register hooks with register_full_backward_hook.
Forward pass.
<class 'torch.nn.modules.linear.Linear'>: id of input, input[0], and output: 140708990389904, 140707010514688, 140707010514848
<class 'torch.nn.modules.linear.Linear'>: id of input, input[0], and output: 140708990443088, 140707010514848, 140707010531728
<class 'torch.nn.modules.container.Sequential'>: id of input, input[0], and output: 140708989964432, 140707010514448, 140707010531808
<class 'torch.nn.modules.loss.CrossEntropyLoss'>: id of input, input[0], and output: 140708989841728, 140707010531808, 140707010532128
```
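One observation about the output: although the ids differ, no data is copied, so the repackaged tensors still share storage with the originals. A possible way to re-associate a module's output with the next module's input is therefore `data_ptr()` rather than `id()`. This is a sketch relying on an undocumented detail (storage sharing), not a documented guarantee:

```python
import torch
from torch.nn import CrossEntropyLoss, Linear, Sequential

records = {}  # module -> (data_ptr of input[0], data_ptr of output)

def fwd_hook(module, inp, out):
    records[module] = (inp[0].data_ptr(), out.data_ptr())

model = Sequential(Linear(3, 3), Linear(3, 3))
loss_fn = CrossEntropyLoss()
for m in list(model.modules()) + [loss_fn]:
    m.register_forward_hook(fwd_hook)
    m.register_full_backward_hook(lambda mod, gi, go: None)

loss_fn(model(torch.rand(2, 3, requires_grad=True)),
        torch.randint(3, (2,)))

# the output of the second Linear and the input of the loss module share
# storage even though their Python objects (and ids) differ
assert records[model[1]][1] == records[loss_fn][0]
```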