Recommendations for how to best handle a generic module wrapper with jit.script?

Consider wrappers such as nn.DataParallel. Setting aside the other specifics of what that class does, the basic pattern is a class that simply wraps a base module and forwards calls to it:

```python
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, base_module: nn.Module):
        super().__init__()
        self.base_module = base_module
```

The real trouble comes with generically wrapping the forward pass:

```python
    def forward(self, *args, **kwargs):
        # Pass all arguments through to the wrapped module and return its result
        return self.base_module(*args, **kwargs)
```
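
To make the problem concrete, here is a small check, a sketch assuming a recent PyTorch release. `StarWrapper` and `try_script` are illustrative names, not part of any API. The wrapper works in eager mode, but `torch.jit.script` rejects the variadic `forward`; since the exact exception type and message vary between versions, the helper catches any exception.

```python
import torch
import torch.nn as nn


class StarWrapper(nn.Module):
    """Generic forwarder using Python's star syntax (illustrative name)."""

    def __init__(self, base_module: nn.Module):
        super().__init__()
        self.base_module = base_module

    def forward(self, *args, **kwargs):
        return self.base_module(*args, **kwargs)


def try_script(module: nn.Module) -> bool:
    """Return True if torch.jit.script accepts the module, False otherwise."""
    try:
        torch.jit.script(module)
        return True
    except Exception as exc:  # exception type differs across PyTorch versions
        print(f"scripting failed: {type(exc).__name__}")
        return False


eager = StarWrapper(nn.Linear(4, 2))
print(eager(torch.randn(3, 4)).shape)  # eager mode forwards *args without issue
scriptable = try_script(eager)  # False: TorchScript cannot compile *args/**kwargs
```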

I would like the wrapper itself to be scriptable with jit.script whenever base_module is. In my particular use case, I don't actually need to add anything to the forward pass. I know I may simply have to require forward to take a single argument, but I would like to know whether anyone has handled this in other ways.
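
The single-argument restriction mentioned above could look like the following minimal sketch, assuming every wrapped module takes exactly one Tensor and returns one Tensor. `TensorWrapper` is a hypothetical name. Because `forward` now has a fixed, annotated signature, TorchScript can type-check the call into `base_module`, and the submodule is compiled recursively.

```python
import torch
import torch.nn as nn


class TensorWrapper(nn.Module):
    """Wrapper restricted to a single Tensor input (hypothetical name)."""

    def __init__(self, base_module: nn.Module):
        super().__init__()
        self.base_module = base_module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Fixed signature: TorchScript can compile this call
        return self.base_module(x)


base = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
scripted = torch.jit.script(TensorWrapper(base))

x = torch.randn(3, 4)
out = scripted(x)
print(out.shape)  # torch.Size([3, 2])
```

TorchScript treats unannotated parameters as Tensor, so the annotation on `x` is optional here; it becomes necessary as soon as the wrapped modules expect something other than a single Tensor.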

It depends on why you want such a forwarder compiled. TorchScript cannot compile methods that take *args or **kwargs, so basically you'll have to either avoid Python's star syntax (there are several jit-compatible ways to group arguments, but they are intrusive with respect to base_module, i.e. every wrapped module has to implement a special interface) or keep some non-compiled code in your model.
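
Both options could be sketched as follows, under stated assumptions. For the first, the "special interface" is assumed to be a single `Dict[str, torch.Tensor]` argument, which is only one of several possible groupings (a tuple or a custom class would be others). For the second, the generic wrapper stays in eager Python and only the wrapped module is passed to `torch.jit.script`. `AddScaled`, `DictWrapper`, and `EagerWrapper` are hypothetical names.

```python
from typing import Dict

import torch
import torch.nn as nn


# Option 1: a hypothetical shared interface. Every wrapped module takes one
# Dict[str, Tensor] instead of arbitrary positional/keyword arguments.
class AddScaled(nn.Module):  # hypothetical base module implementing the interface
    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        return inputs["x"] * 2.0 + inputs["y"]


class DictWrapper(nn.Module):
    def __init__(self, base_module: nn.Module):
        super().__init__()
        self.base_module = base_module

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        # The grouped argument is passed through unchanged
        return self.base_module(inputs)


# Option 2: keep the generic wrapper as eager Python code and compile only
# the wrapped module; the *args/**kwargs forwarding is never scripted.
class EagerWrapper(nn.Module):
    def __init__(self, base_module: nn.Module):
        super().__init__()
        self.base_module = base_module

    def forward(self, *args, **kwargs):
        return self.base_module(*args, **kwargs)


scripted_dict = torch.jit.script(DictWrapper(AddScaled()))
inputs = {"x": torch.ones(2), "y": torch.zeros(2)}
print(scripted_dict(inputs))  # tensor([2., 2.])

partly_scripted = EagerWrapper(torch.jit.script(nn.Linear(4, 2)))
print(partly_scripted(torch.randn(3, 4)).shape)  # torch.Size([3, 2])
```

The trade-off matches the two choices above: the dict interface makes the whole wrapper scriptable but forces every wrapped module to accept `Dict[str, Tensor]`, while `EagerWrapper` accepts any signature but cannot itself be passed to `torch.jit.script`.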