Potential solution to different forward for train and inference + IDE support for forward args

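Hello! One thing that has somewhat frustrated me is the way forward works in nn.Module. Namely, each module gets a single forward, so training and inference code paths that need different arguments have to share it, and calling the module as mod(...) gives the IDE no hints about forward’s arguments.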

Here’s some simple code that wraps nn.Module to allow multiple forward methods with IDE support (also available, with docstrings and imports, in this gist):

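import functools
from typing import Callable, List

import torch

def add_hooks(new_forward_func):
    # Call the module instance (self(...)) instead of new_forward_func directly,
    # so nn.Module.__call__ runs the forward/backward hooks around forward().
    @functools.wraps(new_forward_func)  # keep name/docstring; inspect.signature follows __wrapped__
    def wrapper(self: "MultiforwardTorchModule", *args, **kwargs):
        return self(new_forward_func, self, *args, **kwargs)
    return wrapper

class MultiforwardTorchModule(torch.nn.Module):
    def forward(self, actual_forward: Callable, *args, **kwargs):
        # Dispatch to whichever method add_hooks passed in.
        return actual_forward(*args, **kwargs)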

Then you can do something like this:

class MyModule(MultiforwardTorchModule):
    @add_hooks
    def forward_train(
        self,  # the wrapper passes the module instance through as the first argument
        hidden_state: torch.Tensor,
        teacher_force_seq: List[str],
        # training-specific args...
    ) -> torch.Tensor:
        loss = ...  # training logic goes here
        return loss

    @add_hooks
    def forward_inference(
        self,
        hidden_state: torch.Tensor,
        beam_size: int,
        # inference-specific args...
    ) -> List[str]:
        result = ...  # inference logic goes here
        return result

# mod = MyModule()
# instead of this: mod(foo, bar)
# we can do this: mod.forward_train(foo, bar)
# and the forward/backward hooks still get called

Sharing this in the hope that it is useful to someone.

Also, if anyone with a better understanding of PyTorch internals has any insight into whether this is a terrible idea or not, that would be great. To be honest, I am not really sure what forward/backward hooks are used for, so I do not fully understand why it is important that they are called.
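A small self-contained check of the hook claim, sketched with a hypothetical `TinyModel` standing in for MyModule: a hook registered with `register_forward_hook` fires for each entry point because every call goes through `nn.Module.__call__`, whereas running the forward logic without going through `__call__` would skip it. Forward hooks are a common way to inspect or log a module's inputs and outputs without editing its code. The model, its layer sizes, and the `record_call` hook are illustrative assumptions, not part of the original gist.

```python
from typing import Callable, List

import torch


def add_hooks(new_forward_func):
    def wrapper(self: "MultiforwardTorchModule", *args, **kwargs):
        # Calling the instance runs nn.Module.__call__, which wraps forward()
        # with the registered pre-hooks and forward hooks.
        return self(new_forward_func, self, *args, **kwargs)
    return wrapper


class MultiforwardTorchModule(torch.nn.Module):
    def forward(self, actual_forward: Callable, *args, **kwargs):
        return actual_forward(*args, **kwargs)


class TinyModel(MultiforwardTorchModule):
    # Hypothetical model: one linear layer, a training entry point that
    # returns a loss, and an inference entry point that returns labels.
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)

    @add_hooks
    def forward_train(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.mse_loss(self.linear(x), target)

    @add_hooks
    def forward_inference(self, x: torch.Tensor, threshold: float) -> List[bool]:
        return (self.linear(x).squeeze(-1) > threshold).tolist()


calls: List[str] = []


def record_call(module, args, output):
    # Forward hook signature: (module, positional inputs to forward, output).
    # args[0] is the method that MultiforwardTorchModule.forward dispatched to.
    calls.append(args[0].__name__)


model = TinyModel()
model.register_forward_hook(record_call)

x = torch.randn(3, 4)
loss = model.forward_train(x, torch.zeros(3, 1))
loss.backward()  # gradients flow back through the wrapper as usual
preds = model.forward_inference(x, threshold=0.0)
print(calls)  # ['forward_train', 'forward_inference']
```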
