Dynamo Trace with Parameter Lifting

Hi,

During the Dynamo trace phase, is there a suitable way to lift all of a Module's parameters into input arguments? Our requirement: on each subsequent inference call, we want to feed the model parameters in as function arguments (rather than having them functionalized during the aot_autograd phase), while still allowing the backend compiler to see a functional graph.
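
In other words, we want the calling convention of torch.func.functional_call, where parameters are passed in explicitly. A minimal sketch of what we mean (functional_call is an existing API; run_with_params is just an illustrative wrapper):

import torch
from torch.func import functional_call

def run_with_params(m, params, x):
    # Run the module with externally supplied parameters: the entries
    # in `params` override the module's own parameters for this call.
    return functional_call(m, params, (x,))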

Why do this?
Our scenario: the online model's parameters are updated every 10-30 minutes.

Why not use torch.export?
Unfortunately, our model cannot always be traced into a single graph.

Why not recompile?
Unfortunately, our users' models take a long time to trace and compile, longer than the interval between two rounds of parameter updates.

What did we try, and why doesn't it work?
We tried creating a new callable that updates the parameters and then calls the original forward method. A demo is shown below:

import torch

# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.kernel = torch.nn.Parameter(torch.empty(256, 4, device="cuda:0"))
        self.relu = torch.nn.ReLU()

    def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
        a = torch.matmul(x, self.kernel)
        b = self.relu(a)
        return b

def create_call_func(m):
    orig_forward = m.forward
    def call_func(x, kernel):
        # Swap the new kernel in behind the module's back,
        # then run the original forward.
        m._parameters["kernel"].data = kernel
        return orig_forward(x)
    return call_func
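
Roughly how we intend to use it (the names here are illustrative):

m = M()
call_func = create_call_func(m)
compiled = torch.compile(call_func)

x = torch.randn(8, 256, device="cuda:0")
new_kernel = torch.randn(256, 4, device="cuda:0")

# The hope: each call uses the freshly passed kernel, with no recompile.
out = compiled(x, new_kernel)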

We’ve encountered quite a few issues, including:

  1. During the first compilation trace, the tensors passed as input arguments need to have the same Python object identity (id(variable)) as the parameters stored in the model.
  2. Guard failures.
  3. Dynamo's handling of side effects (our call_func mutates m._parameters during tracing).

This makes us question whether our current approach is incorrect.

Alternatives
Should we simply separate the parameter-update logic from the forward function, and use the relevant torch.nn.Module methods to replace the parameters outside the compiled call?
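
A minimal sketch of this alternative, assuming the parameter shapes stay fixed (copying new values in place keeps the same Parameter objects alive, so the compiled graph's guards should remain valid):

import torch

m = M()  # the module from the demo above
compiled = torch.compile(m)

def update_params(module, new_state):
    # Copy new values into the existing parameters in place; the
    # Parameter objects (and their ids) are unchanged, so no recompile
    # should be triggered as long as shapes and dtypes match.
    with torch.no_grad():
        for name, tensor in new_state.items():
            module.get_parameter(name).copy_(tensor)

# Every 10-30 minutes:
update_params(m, {"kernel": torch.randn(256, 4, device="cuda:0")})
out = compiled(torch.randn(8, 256, device="cuda:0"))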

Looking forward to any suggestions.

> Our scenario: the online model's parameters are updated every 10-30 minutes.

Can you describe in a bit more detail what this means? Are you updating the data of your model's parameters, or are you taking your nn.Module and slamming in fresh nn.Parameters (and you would like to avoid recompiles even though the parameter objects have changed)? This sounds like something that should work (you can avoid the recompiles as long as your parameter shapes aren't changing), although some code examples would be helpful.
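
For concreteness, a sketch of the two update styles being distinguished here (module and tensor names are illustrative):

import torch

m = M()
new_kernel = torch.randn(256, 4, device="cuda:0")

# Style 1: update the data of the existing parameter in place.
# The Parameter object (and its id) survives, so identity-based guards hold.
with torch.no_grad():
    m.kernel.copy_(new_kernel)

# Style 2: slam in a fresh nn.Parameter object.
# The old object is replaced wholesale, which can invalidate guards.
m.kernel = torch.nn.Parameter(new_kernel)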