Hello! One thing that has somewhat frustrated me is the way forward works in nn.Module. Namely:
- If a module takes different args in training and inference, you end up writing one big forward that combines all of the args (a sketch of this pattern follows the list).
- IDEs are not able to provide code completion / static analysis based on the forward signature.
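For illustration, the kind of combined forward this leads to might look something like the following (a hypothetical seq2seq-style module, just to show the pattern; the names here are made up):

```python
import torch

class Seq2SeqModel(torch.nn.Module):  # hypothetical example, not from the code below
    def forward(self, hidden_state, teacher_force_seq=None, beam_size=None):
        # One combined signature; which args are actually required depends on the mode
        if self.training:
            ...  # use teacher_force_seq and return a loss
        else:
            ...  # run beam search with beam_size and return decoded strings
```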
Others seem to have this issue as well, without a particularly clean solution [Any different between model(input) and model.forward(input) , Can forward() in nn.module be override with different arguments? ]
Here’s some simple code that wraps nn.Module to allow multiple forward methods with IDE support (also available with docstrings and imports in this gist):
```python
from typing import Callable, List

import torch

def add_hooks(new_forward_func: Callable):
    def wrapper(self: "MultiforwardTorchModule", *args, **kwargs):
        # Route the call through nn.Module.__call__ so the forward/backward hooks still fire
        return self(new_forward_func, self, *args, **kwargs)
    return wrapper

class MultiforwardTorchModule(torch.nn.Module):
    def forward(self, actual_forward: Callable, *args, **kwargs):
        # __call__ hands us the decorated function; just dispatch to it
        return actual_forward(*args, **kwargs)
```
Then you can do something like
```python
class MyModule(MultiforwardTorchModule):
    @add_hooks
    def forward_train(
        self,
        hidden_state: torch.Tensor,
        teacher_force_seq: List[str],
        # training-specific args...
    ) -> torch.Tensor:
        # ....
        return loss

    @add_hooks
    def forward_inference(
        self,
        hidden_state: torch.Tensor,
        beam_size: int,
        # inference-specific args...
    ) -> List[str]:
        # ....
        return result

# mod = MyModule()
# instead of this: mod(foo, bar)
# we can do this:  mod.forward_train(foo, bar)
# and still have the forward/backward hooks called
```
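As a quick sanity check that the hooks really do still run, something like this should print from a registered forward hook (continuing from the definitions above; ToyModule and its layer are hypothetical, just for the demo):

```python
class ToyModule(MultiforwardTorchModule):  # hypothetical toy module for the demo
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 1)

    @add_hooks
    def forward_train(self, hidden_state: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_state).mean()

mod = ToyModule()
mod.register_forward_hook(lambda module, inputs, output: print("forward hook fired"))
loss = mod.forward_train(torch.randn(2, 4))  # prints "forward hook fired"
loss.backward()
```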
Sharing this in the hope it’s useful to someone.
Also, if anyone with a better understanding of PyTorch internals has any insight on whether this is a terrible idea or not, that would be great. To be honest, I am not really sure what forward/backward hooks are used for, so I do not fully understand why it is important that they are called.