Is it possible to forward a custom module with weights without rewriting the module?

Hi! I'm currently trying to adapt meta-learning (MAML, https://arxiv.org/pdf/1703.03400.pdf) to my custom module. For an already defined module, e.g. the widely used vgg16, we would usually use it as:

fc_out = vgg16(x)
loss = F.cross_entropy(fc_out, label)
loss.backward()

Is there some feature of PyTorch that lets us simply write:
fc_out = vgg16(x, params)
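
(For reference, recent PyTorch versions expose something very close to this call pattern through torch.func.functional_call. A minimal sketch, assuming PyTorch 2.x and torchvision's vgg16; x and the parameter dict here are just placeholders:)

import torch
from torchvision.models import vgg16

model = vgg16()
# collect parameters and buffers by their dotted names
params_and_buffers = {name: p for name, p in model.named_parameters()}
params_and_buffers.update({name: b for name, b in model.named_buffers()})

x = torch.randn(1, 3, 224, 224)
# run a forward pass with an explicitly passed set of weights,
# without touching the module's own state
fc_out = torch.func.functional_call(model, params_and_buffers, (x,))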

I noticed that the widely used MAML PyTorch implementation https://github.com/katerakelly/pytorch-maml rewrites the forward function as shown below. But in more complex cases, where my module contains many nested self-defined modules, rewriting every module this way would be arduous.
from collections import OrderedDict
import torch.nn as nn
# conv2d, batchnorm, relu, maxpool, linear used below are functional
# wrappers defined elsewhere in that repo (thin wrappers around torch.nn.functional)

class OmniglotNet(nn.Module):
    '''
    The base model for few-shot learning on Omniglot
    '''

    def __init__(self, num_classes, loss_fn, num_in_channels=3):
        super(OmniglotNet, self).__init__()
        # Define the network
        self.features = nn.Sequential(OrderedDict([
                ('conv1', nn.Conv2d(num_in_channels, 64, 3)),
                ('bn1', nn.BatchNorm2d(64, momentum=1, affine=True)),
                ('relu1', nn.ReLU(inplace=True)),
                ('pool1', nn.MaxPool2d(2, 2)),
                ('conv2', nn.Conv2d(64, 64, 3)),
                ('bn2', nn.BatchNorm2d(64, momentum=1, affine=True)),
                ('relu2', nn.ReLU(inplace=True)),
                ('pool2', nn.MaxPool2d(2, 2)),
                ('conv3', nn.Conv2d(64, 64, 3)),
                ('bn3', nn.BatchNorm2d(64, momentum=1, affine=True)),
                ('relu3', nn.ReLU(inplace=True)),
                ('pool3', nn.MaxPool2d(2, 2))
        ]))
        self.add_module('fc', nn.Linear(64, num_classes))

        # Define loss function
        self.loss_fn = loss_fn

        # Initialize weights
        self._init_weights()

    def forward(self, x, weights=None):
        ''' Define what happens to data in the net '''
        if weights is None:
            x = self.features(x)
            x = x.view(x.size(0), 64)
            x = self.fc(x)
        else:
            # functional path: every layer is applied with explicitly passed weights
            x = conv2d(x, weights['features.conv1.weight'], weights['features.conv1.bias'])
            x = batchnorm(x, weight=weights['features.bn1.weight'], bias=weights['features.bn1.bias'], momentum=1)
            x = relu(x)
            x = maxpool(x, kernel_size=2, stride=2)
            x = conv2d(x, weights['features.conv2.weight'], weights['features.conv2.bias'])
            x = batchnorm(x, weight=weights['features.bn2.weight'], bias=weights['features.bn2.bias'], momentum=1)
            x = relu(x)
            x = maxpool(x, kernel_size=2, stride=2)
            x = conv2d(x, weights['features.conv3.weight'], weights['features.conv3.bias'])
            x = batchnorm(x, weight=weights['features.bn3.weight'], bias=weights['features.bn3.bias'], momentum=1)
            x = relu(x)
            x = maxpool(x, kernel_size=2, stride=2)
            x = x.view(x.size(0), 64)
            x = linear(x, weights['fc.weight'], weights['fc.bias'])
        return x
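
For context, the point of the weights dict is that the MAML inner loop can build "fast weights" that stay inside the autograd graph. A rough sketch of how such a forward is used (net, loss_fn, inner_lr and the support/query tensors are placeholders, not names from that repo):

import torch
from collections import OrderedDict

# start from the module's own parameters, keyed by their dotted names
fast_weights = OrderedDict(net.named_parameters())

logits = net(x_support)                     # first pass uses the module's own weights
loss = loss_fn(logits, y_support)

# create_graph=True keeps the inner update differentiable (grad over grad)
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)
fast_weights = OrderedDict(
    (name, p - inner_lr * g)
    for (name, p), g in zip(fast_weights.items(), grads)
)

# second pass runs with the adapted weights; the meta-gradient can still
# flow back to the original parameters through the update above
query_logits = net(x_query, weights=fast_weights)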

Hi,

If you just want to reload all the weights, you can simply use load_state_dict(), either on self or on a deepcopy of self (if you don't want the original module to be changed), and then run the regular forward on that.
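
Something along these lines, as a sketch (new_weights stands for whatever state_dict-shaped dict of weights you want to run with):

import copy

# copy the module so the original is left untouched
runner = copy.deepcopy(vgg16)
runner.load_state_dict(new_weights)   # new_weights: a dict in state_dict() format

fc_out = runner(x)                    # regular forward with the swapped-in weights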

Hi, but I need to implement the meta-learning training, which involves taking gradients of gradients: the inner-loop updated weights have to stay in the computation graph, and load_state_dict() only copies tensor values, so the graph would be cut.
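
A minimal illustration of the issue (net, loss_fn, x, y are placeholders):

import copy
import torch

adapted = copy.deepcopy(net)
adapted.load_state_dict(net.state_dict())   # plain value copy, builds no graph

inner_loss = loss_fn(adapted(x), y)
grads = torch.autograd.grad(inner_loss, adapted.parameters(), create_graph=True)
# `grads` relate only to adapted's own (leaf) parameters; nothing connects them
# back to net.parameters(), so the MAML meta-gradient with respect to the
# original weights cannot be computed this way.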