I have a large pretrained `torch.nn.Module` that is being used by many different `torch.nn.Module`s.
For the purpose of example, assume it is a pretrained & fixed ResNet image model that I use for feature generation in many different image classifiers.
How can I best store such a model within a `torch.nn.Module` if I do not want its weights to be stored as part of the model state?
Option 1:
Simply storing it as a child module:
```python
class MyModel(torch.nn.Module):
    def __init__(self, resnet: ResNet):
        super().__init__()
        self.resnet = resnet
```
This would result in its parameters being stored as part of the `Module`'s state and would increase the checkpoint size.
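For example (a minimal, self-contained sketch; I use `torchvision`'s `resnet18` and a made-up classifier head just as stand-ins for the real models), the backbone's weights end up in every checkpoint:

```python
import torch
import torchvision

class MyModel(torch.nn.Module):
    def __init__(self, resnet: torch.nn.Module):
        super().__init__()
        self.resnet = resnet                   # registered as a child module
        self.head = torch.nn.Linear(1000, 10)  # made-up classifier head

model = MyModel(torchvision.models.resnet18())
# All ResNet parameters are part of the state dict and get saved with every checkpoint:
print(sum(k.startswith("resnet.") for k in model.state_dict()))
```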
Option 2:
Store the “reference”:
```python
class MyModel(torch.nn.Module):
    def __init__(self, resnet: ResNet):
        super().__init__()
        self._resnet = [resnet]
```
With this approach the `Module` just stores a reference to the ResNet and does not really manage it. The syntax is a bit annoying.
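Concretely, every use of the backbone then goes through the list indirection (the `forward` and head below are made up just to illustrate the call site):

```python
import torch

class MyModel(torch.nn.Module):
    def __init__(self, resnet: torch.nn.Module):
        super().__init__()
        self._resnet = [resnet]                 # plain list, so not registered
        self.head = torch.nn.Linear(1000, 10)   # made-up classifier head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self._resnet[0](x)           # the [0] indirection everywhere
        return self.head(features)
```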
Is there a better way to store weak references?
Also, if I want to only enable device management of the child module, but not weight ownership, is overriding the `apply` method the best option?
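For context, this is roughly what I have in mind; I am assuming the internal `_apply` hook (which `.to()`, `.cuda()`, etc. go through) is where one would hook in, rather than the public `apply`:

```python
import torch

class MyModel(torch.nn.Module):
    def __init__(self, resnet: torch.nn.Module):
        super().__init__()
        self._resnet = [resnet]  # not registered, so not in state_dict()

    def _apply(self, fn, *args, **kwargs):
        # Forward device/dtype conversions to the unregistered ResNet as well,
        # so that model.to(device) also moves it, without owning its weights.
        self._resnet[0]._apply(fn, *args, **kwargs)
        return super()._apply(fn, *args, **kwargs)
```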