Hi! I'm currently trying to adapt meta-learning (MAML, https://arxiv.org/pdf/1703.03400.pdf) to my custom module. For a given module, e.g. the widely used vgg16, we would usually use it like this:
fc_out = vgg16(x)
loss = F.cross_entropy(fc_out, label)
loss.backward()
Is there some feature of PyTorch that lets us simply write:
fc_out = vgg16(x, params)
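(For concreteness, what I am hoping for is something along the lines of torch.func.functional_call from newer PyTorch releases, which runs a module's forward pass with an externally supplied parameter dict. A minimal sketch, assuming PyTorch >= 2.0 and torchvision's vgg16; the dummy batch and labels are my own, just for illustration:)

import torch
import torch.nn.functional as F
from torchvision import models

vgg16 = models.vgg16()
x = torch.randn(4, 3, 224, 224)       # dummy batch
label = torch.randint(0, 1000, (4,))  # dummy targets

# Any dict of named tensors works here, e.g. "fast weights"
# produced by a MAML inner-loop update.
params = dict(vgg16.named_parameters())

fc_out = torch.func.functional_call(vgg16, params, (x,))
loss = F.cross_entropy(fc_out, label)
loss.backward()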
I observed that the widely used MAML PyTorch implementation https://github.com/katerakelly/pytorch-maml rewrites the forward function as shown below. But in more complex cases, where my module contains many self-defined submodules, rewriting every module in this style would be arduous:
from collections import OrderedDict

import torch.nn as nn
from layers import *  # the repo's functional wrappers: conv2d, batchnorm, relu, maxpool, linear

class OmniglotNet(nn.Module):
    '''
    The base model for few-shot learning on Omniglot
    '''
    def __init__(self, num_classes, loss_fn, num_in_channels=3):
        super(OmniglotNet, self).__init__()
        # Define the network
        self.features = nn.Sequential(OrderedDict([
            ('conv1', nn.Conv2d(num_in_channels, 64, 3)),
            ('bn1', nn.BatchNorm2d(64, momentum=1, affine=True)),
            ('relu1', nn.ReLU(inplace=True)),
            ('pool1', nn.MaxPool2d(2, 2)),
            ('conv2', nn.Conv2d(64, 64, 3)),
            ('bn2', nn.BatchNorm2d(64, momentum=1, affine=True)),
            ('relu2', nn.ReLU(inplace=True)),
            ('pool2', nn.MaxPool2d(2, 2)),
            ('conv3', nn.Conv2d(64, 64, 3)),
            ('bn3', nn.BatchNorm2d(64, momentum=1, affine=True)),
            ('relu3', nn.ReLU(inplace=True)),
            ('pool3', nn.MaxPool2d(2, 2))
        ]))
        self.add_module('fc', nn.Linear(64, num_classes))
        # Define loss function
        self.loss_fn = loss_fn
        # Initialize weights
        self._init_weights()

    def forward(self, x, weights=None):
        ''' Define what happens to data in the net '''
        if weights is None:
            # Standard path: use the module's own parameters
            x = self.features(x)
            x = x.view(x.size(0), 64)
            x = self.fc(x)
        else:
            # Functional path: replay every layer with the externally
            # supplied weights dict instead of the module's parameters
            x = conv2d(x, weights['features.conv1.weight'], weights['features.conv1.bias'])
            x = batchnorm(x, weight=weights['features.bn1.weight'], bias=weights['features.bn1.bias'], momentum=1)
            x = relu(x)
            x = maxpool(x, kernel_size=2, stride=2)
            x = conv2d(x, weights['features.conv2.weight'], weights['features.conv2.bias'])
            x = batchnorm(x, weight=weights['features.bn2.weight'], bias=weights['features.bn2.bias'], momentum=1)
            x = relu(x)
            x = maxpool(x, kernel_size=2, stride=2)
            x = conv2d(x, weights['features.conv3.weight'], weights['features.conv3.bias'])
            x = batchnorm(x, weight=weights['features.bn3.weight'], bias=weights['features.bn3.bias'], momentum=1)
            x = relu(x)
            x = maxpool(x, kernel_size=2, stride=2)
            x = x.view(x.size(0), 64)
            x = linear(x, weights['fc.weight'], weights['fc.bias'])
        return x
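For context, this functional forward is driven by the MAML inner loop roughly as follows. The sketch below is illustrative, not taken from the repo (variable names like x_s/y_s for the support set and x_q/y_q for the query set are my own), and it assumes the full OmniglotNet class from the repo:

import torch
import torch.nn.functional as F
from collections import OrderedDict

net = OmniglotNet(num_classes=5, loss_fn=F.cross_entropy)
inner_lr = 0.01                      # illustrative inner-loop step size
x_s = torch.randn(5, 3, 28, 28)      # dummy support images (Omniglot-sized)
y_s = torch.randint(0, 5, (5,))
x_q = torch.randn(5, 3, 28, 28)      # dummy query images
y_q = torch.randint(0, 5, (5,))

# Inner loop: one gradient step on the support set, keeping the graph
# (create_graph=True) so the meta-gradient can flow through the update.
loss = net.loss_fn(net(x_s), y_s)
grads = torch.autograd.grad(loss, net.parameters(), create_graph=True)
fast_weights = OrderedDict(
    (name, p - inner_lr * g)
    for (name, p), g in zip(net.named_parameters(), grads)
)

# Outer loop: evaluate the adapted weights on the query set; backward()
# then produces gradients w.r.t. the original parameters.
meta_loss = net.loss_fn(net(x_q, weights=fast_weights), y_q)
meta_loss.backward()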