How to script the model and lazily load parameters at the same time?

Hi, I am working on a tool that makes a model load its parameters only when they are needed, to reduce peak memory (link). I would also like to take advantage of torch.jit optimizations, but scripting fails. Here is reproducible code:

import functools
import torch
import torch.nn as nn
import torch.nn.functional as F  # for torch.jit.script


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32, 32)
        self.fc2 = nn.Linear(32, 32)

    def forward(self, input):
        out = self.fc1(input)
        out = self.fc2(out)
        return out


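def release_weights(module, _input=None, _output=None):
    # Move the module's parameters (and buffers) back to CPU after its forward pass.
    module.to('cpu')


def load_weights(module, _input=None):
    # Move the module's parameters (and buffers) to the GPU right before its forward pass.
    module.to('cuda')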


def lazy_loading(func, module):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        load_weights(module)
        res = func(*args, **kwargs)
        release_weights(module)
        return res
    return wrapper

def main():
    model = Net().eval()
    data = torch.rand(1, 32)

    for module in [model.fc1, model.fc2]:
        # method 1: decorator
        module.forward = lazy_loading(module.forward, module)

        # method 2: hook
        # module.register_forward_hook(release_weights)
        # module.register_forward_pre_hook(load_weights)

    data = data.cuda()

    with torch.no_grad():
        model = torch.jit.script(model)
        output = model(data)


main()

I’ve tried:

  1. Adding the decorator (before or after torch.jit.script): the decorator seems to be ignored.
  2. Adding the hooks before torch.jit.script: scripting fails with a type mismatch error for the hooks.
  3. Adding the hooks after torch.jit.script: it raises RuntimeError: register_forward_hook is not supported on ScriptModules.

I wonder whether it is possible to script the model and lazily load parameters at the same time (without directly modifying the model's internal code).
Any feedback would be greatly appreciated. Thanks!
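One possible workaround, sketched here under assumptions and not tested against the tool linked above: keep the device moves in eager Python and script only the leaf submodules. From Python's point of view a scripted submodule is still an nn.Module, so a plain wrapper module can call `.to()` on it around each forward call. `LazyLoad` and `wrap_lazily` are hypothetical names, and the sketch falls back to CPU when CUDA is unavailable so that it still runs.

```python
from typing import Iterable

import torch
import torch.nn as nn


class Net(nn.Module):
    # Same toy model as in the question.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32, 32)
        self.fc2 = nn.Linear(32, 32)

    def forward(self, input):
        out = self.fc1(input)
        out = self.fc2(out)
        return out


class LazyLoad(nn.Module):
    """Hypothetical eager-mode wrapper around a scripted submodule.

    Moves the submodule's parameters to `compute_device` right before its
    forward pass and back to `storage_device` right after it.
    """

    def __init__(self, scripted: torch.jit.ScriptModule,
                 compute_device: str, storage_device: str = "cpu"):
        super().__init__()
        self.inner = scripted
        self.compute_device = compute_device
        self.storage_device = storage_device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.inner.to(self.compute_device)  # "load" the weights (plain Python call)
        try:
            return self.inner(x)  # the computation itself runs as TorchScript
        finally:
            self.inner.to(self.storage_device)  # "release" the weights


def wrap_lazily(model: nn.Module, names: Iterable[str], compute_device: str) -> nn.Module:
    # Hypothetical helper: script each named child and swap in a LazyLoad wrapper.
    for name in names:
        scripted = torch.jit.script(getattr(model, name))
        setattr(model, name, LazyLoad(scripted, compute_device))
    return model


# Assumption: fall back to CPU when CUDA is unavailable so the sketch still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = wrap_lazily(Net().eval(), ["fc1", "fc2"], device)
data = torch.rand(1, 32, device=device)

with torch.no_grad():
    output = model(data)

print(output.shape)  # torch.Size([1, 32])
print(all(p.device.type == "cpu" for p in model.parameters()))  # True
```

The trade-off is that the outer model stays in eager mode, so only the work inside each scripted submodule benefits from TorchScript; the Python-level calls between layers remain.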

torch version: 1.10.1

With method 2 plus type hints (link):

...
def release_weights(module, _input: Tuple[torch.Tensor], _output):
    module.weight.data = module.weight.to('cpu')
...

it fails with:

RuntimeError:
Tried to set an attribute: data on a non-class: Tensor:
  File "/home/siahuat0727/test.py", line 22
def release_weights(module, _input: Tuple[torch.Tensor], _output):
    module.weight.data = module.weight.to('cpu')
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

So the main question seems to be: is it possible to move a ScriptModule's parameters across devices?
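As far as I know, calling `.to()` on a ScriptModule from ordinary Python does move its parameters; the error above comes from attempting the move inside compiled code (here, a hook that gets scripted). A minimal sketch of the Python-side move, assuming CUDA may or may not be present:

```python
import torch
import torch.nn as nn

# Assumption: use CUDA when present, otherwise fall back to CPU so the sketch still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"

scripted = torch.jit.script(nn.Linear(32, 32))

# Called from eager Python (not from scripted code), .to() moves the parameters.
scripted.to(device)
print(scripted.weight.device.type)  # "cuda" on a GPU machine, "cpu" otherwise

# Move them back to CPU ("release") and check that every parameter followed.
scripted.to("cpu")
print(all(p.device.type == "cpu" for p in scripted.parameters()))  # True
```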

Here’s a similar issue: Problem with jit TorchScript while copying data between GRUs (pytorch/pytorch issue #28267 on GitHub).