Training ResNet with latent vector

I would like to get your views on the following: what if I train a ResNet model on the reduced-dimension representation from an autoencoder (i.e., feed the output of the encoder’s bottleneck layer directly to the ResNet classifier)?

Thank you everyone in advance.

Your ResNet might still work if you retrain it, but you might be wasting its feature extraction layers, as your encoder might already have created valid features.
I don’t know what shape the output of the encoder would have, but if it’s not an image tensor, the first convolutions in your ResNet might not be the best choice, as their input might not contain any spatial pattern.
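
E.g. if the latent carries no spatial structure, a plain MLP head on the flattened latent could serve as a baseline to compare against. A minimal, untested sketch (the latent size, hidden size, and class count are arbitrary assumptions):

import torch
import torch.nn as nn

# Hypothetical encoder output: a flat latent vector of size 128 per sample
latent = torch.randn(8, 128)

# Simple MLP classifier head instead of conv-based feature extraction
mlp_head = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
out = mlp_head(latent)  # [8, 10] class logits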

However, as always, you shouldn’t let these assumptions stop you from running some experiments and reporting your findings. :wink:


Yes ptrblck, that’s exactly where I got stuck. The latent dimension is [batch_size, 256, 2, 2], and I am not sure how to feed a 2x2 spatial input to the model. Secondly, if I still want to keep the feature extraction layers in my ResNet model, how should I go about changing the input shape: would that be in the Bottleneck class or the ResNet class for ResNet-50?

import torch.nn as nn
import torch.nn.functional as F


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        # 1x1 -> 3x3 -> 1x1 convolutions with a residual shortcut
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        # Projection shortcut when the spatial size or channel count changes
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        # Stem expects a 3-channel image input (CIFAR-style 32x32)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block in each stage may downsample
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)  # assumes a 4x4 feature map at this point
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


# Note: BasicBlock (used by ResNet18/34) is not shown in this snippet
def ResNet18():
    return ResNet(BasicBlock, [2,2,2,2])

def ResNet34():
    return ResNet(BasicBlock, [3,4,6,3])

def ResNet50():
    return ResNet(Bottleneck, [3,4,6,3])
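
For reference, if one wanted to feed the [batch_size, 256, 2, 2] latent into this model directly, the changes would live in the ResNet class (its stem and final pooling), not in Bottleneck. An untested sketch, assuming 256 latent channels:

model = ResNet50()
# Accept 256 input channels instead of 3; bn1 can stay, since conv1
# still outputs 64 channels.
model.conv1 = nn.Conv2d(256, 64, kernel_size=3, stride=1, padding=1, bias=False)
# With a 2x2 input the feature map reaches layer4 at 1x1, so the fixed
# F.avg_pool2d(out, 4) in forward would fail and would have to be
# replaced, e.g. with F.adaptive_avg_pool2d(out, 1).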

Your input might currently be too small for a ResNet, and you might get an error from a conv layer explaining that the kernel size cannot be bigger than the input shape.

You could try to reshape it such that the number of input channels is set to 1 and the spatial sizes are increased by a factor of 16 via:

x = x.view(x.size(0), 1, x.size(2)*16, x.size(3)*16) # [batch_size, 1, 32, 32]

This works because 256*2*2 = 1*32*32 = 1024 elements per sample, so view just reinterprets the layout without copying data.

And then change the input channels of the first conv layer to 1 via:

model.conv1 = nn.Conv2d(1, 64, 7, 2, 3, bias=False) # For torchvision.models.resnet18

This would at least increase the spatial size. Note that the default input shape is [3, 224, 224], so your input is still small compared to this shape.
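
Putting the two snippets together, a quick shape check could look like this (random data and a 10-class head are assumptions, just to verify the forward pass):

import torch
import torch.nn as nn
import torchvision.models as models

latent = torch.randn(8, 256, 2, 2)            # encoder bottleneck output
x = latent.view(latent.size(0), 1, 32, 32)    # 256*2*2 == 1*32*32 elements

model = models.resnet18(num_classes=10)
model.conv1 = nn.Conv2d(1, 64, 7, 2, 3, bias=False)  # accept a single channel
out = model(x)
print(out.shape)  # torch.Size([8, 10])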
