Hello! One thing that has somewhat frustrated me is the way `forward` works in `nn.Module`s. Namely:
- If a module takes different args during training and inference, you have to cram them all into one big `forward` with a combination of the args (see the sketch after this list)
- IDEs are not able to provide code completion / static analysis based on the `forward` signature, since calls actually go through `__call__`.
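For instance, the merged signature ends up looking something like this (a hypothetical sketch of the anti-pattern; the module and parameter names are made up for illustration):

```python
import torch
from typing import List, Optional


class SeqModule(torch.nn.Module):
    # One forward that merges the training and inference signatures:
    # mode-dependent optional args and a mode-dependent return type.
    def forward(
        self,
        hidden_state: torch.Tensor,
        teacher_force_seq: Optional[List[str]] = None,  # only meaningful in training
        beam_size: Optional[int] = None,                # only meaningful in inference
    ):
        if self.training:
            ...  # compute and return a loss using teacher_force_seq
        else:
            ...  # decode and return strings using beam_size
```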
Others seem to have this issue as well, without a particularly clean solution [Any different between model(input) and model.forward(input), Can forward() in nn.module be override with different arguments?]
Here’s some simple code that wraps `nn.Module` to allow multiple forwards with IDE support (also available with docstrings and imports in this gist):
```python
import torch
from typing import Callable, List


def add_hooks(new_forward_func: Callable):
    def wrapper(self: "MultiforwardTorchModule", *args, **kwargs):
        # Route the call through self.__call__ so nn.Module runs its hook
        # machinery. We pass self along explicitly because new_forward_func
        # is the unbound function at decoration time.
        return self(new_forward_func, self, *args, **kwargs)
    return wrapper


class MultiforwardTorchModule(torch.nn.Module):
    def forward(self, actual_forward: Callable, *args, **kwargs):
        # Dispatch to whichever forward_* method was decorated with @add_hooks
        return actual_forward(*args, **kwargs)
```
Then you can do something like:
```python
class MyModule(MultiforwardTorchModule):
    @add_hooks
    def forward_train(
        self,
        hidden_state: torch.Tensor,
        teacher_force_seq: List[str],  # training-specific args...
    ) -> torch.Tensor:
        # ....
        return loss

    @add_hooks
    def forward_inference(
        self,
        hidden_state: torch.Tensor,
        beam_size: int,  # inference-specific args...
    ) -> List[str]:
        # ....
        return result


# mod = MyModule()
# instead of this:  mod(foo, bar)
# we can do this:   mod.forward_train(foo, bar)
# and still have the forward/backward hooks called
```
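As a quick sanity check that hooks really do still fire, here is a self-contained toy module I put together (not part of the original snippet); a registered forward hook runs even when calling the decorated method directly:

```python
class ToyModule(MultiforwardTorchModule):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)

    @add_hooks
    def forward_train(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x).sum()


toy = ToyModule()
toy.register_forward_hook(lambda module, inputs, output: print("forward hook fired"))
loss = toy.forward_train(torch.randn(2, 4))  # prints "forward hook fired"
loss.backward()  # gradients flow as usual
```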
Sharing this in the hope it’s useful to someone.
Also, if anyone with a better understanding of PyTorch internals has any insight into whether this is a terrible idea or not, that would be great. To be honest, I am not really sure what forward/backward hooks are used for, so I do not fully understand why it is important that they are called.