That code by itself is definitely not enough to accomplish what you’re looking for. A few things to consider -
- Your model will still need a loss function in some context. I am not sure what your plan is for this, but it isn’t too important at the moment.
- When you say the embedding, do you mean the output of the model or the fully connected layer after the embedding? Depending on your application this might be an important distinction. In the ResNet definition, there is no hidden layer between the average pooling and output.
- The standard resnet model in pytorch doesn’t support what you’re looking for by default, but it is not a particularly hard change to make. See below.
import torch.nn as nn
import math
from torchvision.models.resnet import BasicBlock, conv3x3, Bottleneck
class ResNetEmbedding(nn.Module):
    """ResNet backbone whose final fully connected layer projects to a
    fixed-size embedding instead of class logits.

    Args:
        block: residual block class (e.g. torchvision ``BasicBlock`` or
            ``Bottleneck``); must expose an ``expansion`` class attribute
            and accept ``(inplanes, planes, stride, downsample)``.
        layers: sequence of four ints — number of blocks per stage.
        embedding_dim: dimensionality of the output embedding vector.
    """

    def __init__(self, block, layers, embedding_dim=256):
        # BUG FIX: the base class is nn.Module (capital M); ``nn.module``
        # is not a class and fails at class-definition time.
        super(ResNetEmbedding, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        # AdaptiveAvgPool2d(1) is identical to AvgPool2d(7) for the
        # canonical 224x224 input (7x7 feature map here) but also accepts
        # other input resolutions.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Projects pooled features to the embedding; this replaces the
        # classification head of the standard ResNet.
        self.fc = nn.Linear(512 * block.expansion, embedding_dim)
        # Standard torchvision ResNet weight initialization.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one ResNet stage of ``blocks`` residual blocks.

        A 1x1-conv downsample shortcut is added when the spatial stride or
        the channel count changes, so the residual addition type-checks.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )
        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch ``(N, 3, H, W)`` to embeddings ``(N, embedding_dim)``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten pooled features to (N, C)
        x = self.fc(x)
        return x
def resnet18(**kwargs):
    """Construct a ResNet-18-style embedding model.

    Note: unlike the torchvision constructor this takes no ``pretrained``
    flag — the model is randomly initialized.

    Args:
        **kwargs: forwarded to ``ResNetEmbedding`` (e.g. ``embedding_dim``).

    Returns:
        A ``ResNetEmbedding`` built from ``BasicBlock`` with the standard
        ResNet-18 stage layout ``[2, 2, 2, 2]``.
    """
    return ResNetEmbedding(BasicBlock, [2, 2, 2, 2], **kwargs)
Note, a lot of this boilerplate could be avoided by subclassing the existing ResNet class, but this should be close to what you need. Please do heed my note about the output dimensionality of the ResNet model.