Hi, I am working on a tool that makes a model load its parameters only when they are needed, to reduce peak memory usage (link). I would also like to take advantage of the torch.jit optimizations, but scripting the model fails. Here is a reproducible example:
```python
import functools

import torch
import torch.nn as nn
import torch.nn.functional as F  # for torch.jit.script


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32, 32)
        self.fc2 = nn.Linear(32, 32)

    def forward(self, input):
        out = self.fc1(input)
        out = self.fc2(out)
        return out


def release_weights(module, _input=None, _output=None):
    module.to('cpu')


def load_weights(module, _input=None):
    module.to('cuda')


def lazy_loading(func, module):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        load_weights(module)
        res = func(*args, **kwargs)
        release_weights(module)
        return res
    return wrapper


def main():
    model = Net().eval()
    data = torch.rand(1, 32)
    for module in [model.fc1, model.fc2]:
        # method 1: decorator
        module.forward = lazy_loading(module.forward, module)
        # method 2: hook
        # module.register_forward_hook(release_weights)
        # module.register_forward_pre_hook(load_weights)
    data = data.cuda()
    with torch.no_grad():
        model = torch.jit.script(model)
        output = model(data)


main()
```
I’ve tried:

- Adding the decorator (either before or after `torch.jit.script`): the decorator seems to be silently ignored.
- Registering the hooks before `torch.jit.script`: scripting fails with a type mismatch on the hooks (see the annotated-hook sketch after this list).
- Registering the hooks after `torch.jit.script`: it raises `RuntimeError: register_forward_hook is not supported on ScriptModules`.
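Regarding the type mismatch: as far as I can tell, TorchScript only accepts hooks whose signatures are fully type-annotated to match the module's forward. Here is a minimal sketch of how I understand the expected signatures (assuming a single-Tensor input); even with the annotations, I suspect it still won't compile, because `Module.to()` is presumably not scriptable inside a hook:

```python
from typing import Tuple

# Sketch only: hook signatures annotated the way TorchScript seems to
# expect for a module whose forward takes a single Tensor. The bodies may
# still be rejected, since moving a module between devices is probably
# not supported inside a scripted hook.
def load_weights(module, inputs: Tuple[torch.Tensor]):
    module.to('cuda')

def release_weights(module, inputs: Tuple[torch.Tensor], output: torch.Tensor):
    module.to('cpu')
```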
I wonder: is it possible to script the model and lazily load its parameters at the same time, without directly modifying the model's internal code?
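One direction I've been wondering about (an untested sketch, so please correct me): wrapping each submodule from the outside and marking the device moves with `@torch.jit.ignore`, so they stay in Python while the rest gets scripted. I'm not sure whether `.to()` on the wrapped module would even keep the scripted parameters in sync:

```python
class LazySubmodule(nn.Module):
    """Untested sketch: keep the device moves in eager Python via
    torch.jit.ignore while the wrapper itself gets scripted."""

    def __init__(self, inner: nn.Module):
        super().__init__()
        self.inner = inner

    @torch.jit.ignore
    def _load(self) -> None:
        # Dispatched to the Python interpreter even after scripting.
        self.inner.to('cuda')

    @torch.jit.ignore
    def _release(self) -> None:
        self.inner.to('cpu')

    def forward(self, x):
        self._load()
        out = self.inner(x)
        self._release()
        return out


# Usage sketch: wrap the submodules instead of patching their forward.
# model.fc1 = LazySubmodule(model.fc1)
# model.fc2 = LazySubmodule(model.fc2)
```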
Any feedback would be greatly appreciated. Thanks!
torch version: 1.10.1