Adding an STN to prevent network from learning

Hey,

I want to build a CNN that contains several Spatial Transformer Networks (STNs). As a first step I wanted to add one STN, but unfortunately this makes the CNN perform very poorly. The same CNN without the STN (commenting out the line "x = self.stn(x)") reaches 99% validation accuracy within 5 epochs. As soon as the line is added, the accuracy drops to only about 6 percent by the second epoch at the latest, and it stays there no matter how many more epochs are trained.

So I assume that I have some bug in the code that makes the CNN get stuck in such a bad area, but I can’t find it.

I hope one of you can help me. Here is my code:

from functools import partial

import torch
from deconvolution.models.deconv import FastDeconv
from torch import nn as nn
from torch.nn import functional as F
import PIL

class Stn(nn.Module):
    """Spatial Transformer Network module.

    A localization CNN regresses the 6 parameters of a 2x3 affine matrix,
    which is then used to resample the input image. Expects 3-channel
    input; the `250*6*6` flatten size implies 48x48 spatial input
    (48 -> 24 -> 12 -> 6 after the three max-pools) — confirm against the
    dataset loader.
    """

    def __init__(self):
        super(Stn, self).__init__()
        # Spatial transformer localization network: regresses theta (6 values).
        self.loc_net = nn.Sequential(
            nn.MaxPool2d(2, 2),
            nn.Conv2d(3, 250, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(250, 250, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(250 * 6 * 6, 250),
            nn.ReLU(),
            nn.Linear(250, 6),
        )
        # BUG FIX: the final regression layer must start at the identity
        # transform. With default random initialization the STN warps every
        # image with a random affine map from step one, destroying the
        # content downstream layers need — which matches the reported
        # "stuck at ~6% accuracy" symptom. Zero weights + identity bias
        # make the STN a no-op at initialization, so it can learn useful
        # transforms gradually (see the official PyTorch STN tutorial).
        self.loc_net[-1].weight.data.zero_()
        self.loc_net[-1].bias.data.copy_(
            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float)
        )

    def forward(self, x):
        """Warp `x` by the affine transform predicted from `x` itself."""
        theta = self.loc_net(x).view(-1, 2, 3)
        # Pass align_corners explicitly: the default changed in PyTorch 1.3
        # and omitting it emits a warning.
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False)


class Net(nn.Module):
    """CNN classifier with an optional STN front-end.

    Three conv/pool/LocalContextNorm stages followed by two fully-connected
    layers; outputs raw logits over 43 classes (presumably GTSRB traffic
    signs — confirm against the dataset). The `fc1` input size of 12600
    depends on the input spatial resolution; verify against the loader.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Fix: `partial(FastDeconv)` / `partial(Stn)` bound no arguments and
        # were immediately called — a pointless indirection. Instantiate the
        # modules directly.
        self.stn = Stn()
        self.deconv = FastDeconv(3, 3, kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(3, 200, (7, 7), stride=1, padding=2)
        self.conv2 = nn.Conv2d(200, 250, (4, 4), stride=1, padding=2)
        self.conv3 = nn.Conv2d(250, 350, (4, 4), stride=1, padding=2)
        self.fc1 = nn.Linear(12600, 400)
        self.fc2 = nn.Linear(400, 43)
        self.lcn1 = LocalContextNorm(200)
        self.lcn2 = LocalContextNorm(250)
        self.lcn3 = LocalContextNorm(350)
        self.dropout = nn.Dropout2d(0.2)

    def forward(self, x):
        #### Hidden 1 ####
        x = self.deconv(x)
        # x = self.stn(x)  # STN disabled while debugging; re-enable once
        #                  # the STN is identity-initialized (see Stn).
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = self.dropout(x)
        x = self.lcn1(x)
        # NOTE(review): dropout is applied twice in this stage (before and
        # after lcn1) but only once in the later stages — confirm intentional.
        x = self.dropout(x)

        #### Hidden 2 ####
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = self.lcn2(x)
        x = self.dropout(x)

        #### Hidden 3 ####
        x = F.max_pool2d(F.relu(self.conv3(x)), 2)
        x = self.lcn3(x)
        x = self.dropout(x)

        #### Hidden 4 ####
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))

        #### Out ####
        # Raw logits — pair with nn.CrossEntropyLoss (no softmax here).
        return self.fc2(x)
1 Like