Consider wrappers such as nn.DataParallel. Ignoring the other specifics of what that class does, the basic pattern is a class that simply wraps the functionality of a base module:
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, base_module: nn.Module):
        super().__init__()
        # Registered as a submodule so its parameters are tracked
        self.base_module = base_module
The real trouble comes with generically wrapping the forward pass:
    def forward(self, *args, **kwargs):
        # Pass everything through and return the wrapped module's output
        return self.base_module(*args, **kwargs)
I would love for this wrapper itself to be scriptable with torch.jit.script as long as base_module is. In my particular use case, I don't actually need to add anything to the forward pass. I know that I may simply have to require forward to take a single argument, but I would like to know whether anyone has handled this in other ways.
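For reference, the single-argument fallback could look roughly like the sketch below. It assumes the wrapped module takes one tensor and returns one tensor; the class name SingleInputWrapper and the nn.Linear base are only placeholders for illustration, not anything from DataParallel.

```python
import torch
import torch.nn as nn


class SingleInputWrapper(nn.Module):
    """Hypothetical wrapper whose forward is fixed to one tensor argument."""

    def __init__(self, base_module: nn.Module):
        super().__init__()
        self.base_module = base_module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A fixed, annotated signature replaces *args/**kwargs,
        # which TorchScript cannot compile.
        return self.base_module(x)


# Assumption: any scriptable module with a tensor-in, tensor-out forward works here.
base = nn.Linear(4, 2)
scripted = torch.jit.script(SingleInputWrapper(base))

x = torch.randn(3, 4)
out = scripted(x)
```

The cost is that each distinct forward signature needs its own wrapper class (or its own typed forward), so this is a workaround rather than a generic solution.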