Jointly training two models with a single optimizer and loss function does not learn anything

Hello,

I am trying to train a model `net` that creates image embeddings, together with a classifier `clas` that predicts a class from a pair of embeddings. My training routine learns both the embedder and the classifier in an end-to-end fashion.

At test time it is clear that the model has not learned anything, and I'm not sure why.

The code that I use is shown below:

# Build the image embedder and the pair classifier. For each model, put it
# in training mode and move it to the target device.
net, clas = Embedder(), Classifier()
for model in (net, clas):
    model.train()
    model.to(device)

# One optimizer owns the parameters of BOTH models, so a single step()
# updates the embedder and the classifier together (end-to-end training).
optimizer = optim.Adam([*net.parameters(), *clas.parameters()], lr=lr)

criterion = nn.CrossEntropyLoss()

The training loop is:

# Move the batch to the training device. Class-index targets for
# CrossEntropyLoss must be an integer (long) tensor.
images1 = images1.to(device, dtype=torch.float)
images2 = images2.to(device, dtype=torch.float)
labels = labels.to(device, dtype=torch.long)

# ================= FORWARD=================

# Both images go through the same embedder, so its weights are shared.
e1 = net(images1)
e2 = net(images2)
# Use `clas`, the classifier built in the setup whose parameters were
# handed to `optimizer`. Calling any other object here would either fail
# (undefined name) or use a classifier this optimizer never updates.
rta = clas(e1, e2)

loss = criterion(rta, labels)

# ================= BACKWARD =================

optimizer.zero_grad()
loss.backward()
optimizer.step()

I think the problem may be that each embedding is computed in a separate forward pass, but I'm not sure.

I have already done a hyperparameter search and checked that the dataloader works correctly. I also tried several optimizers; none of them worked.

Thanks for any advice <3!

The code looks generally fine.
One additional check you could do is to verify that every module receives valid gradients.
E.g.:

...
optimizer.step()
# Inspect the gradients of every parameter in BOTH models. optimizer.step()
# does not clear .grad, so the values from the last backward() are still
# available here. A single layer can also be checked directly, e.g.
# print(net.some_layer.weight.grad)
for model_name, model in (("net", net), ("clas", clas)):
    for name, param in model.named_parameters():
        if param.grad is None:
            # No gradient reached this parameter: it may be frozen, unused
            # in the forward pass, or backward() was not called.
            print(model_name, name, "grad is None")
        else:
            # The norm is enough to spot vanishing (~0) or exploding grads
            # without dumping the whole tensor.
            print(model_name, name, param.grad.norm().item())

Thanks for your reply!

I found that the main problem in my code was in the definition of the network.

I checked the gradients manually, and they are close to 0 (around 1e-11).

My quick solution was to redefine the embedder and create a new joint model, `JointModel`, which embeds both images in a single forward pass. It worked perfectly.

class JointModel(nn.Module):
    """Wrap an embedder and a pair classifier into one end-to-end module.

    Both inputs are concatenated along dim 0 and embedded in a single
    forward pass of ``net``. The result is split back into the two
    embeddings, which are passed to ``classifier``.

    Args:
        net: embedding module applied to both inputs.
        classifier: callable taking ``(e1, e2)`` and returning its output
            (e.g. logits).
    """

    def __init__(self, net, classifier):
        super(JointModel, self).__init__()
        self.emb = net
        self.clas = classifier

    def forward(self, x1, x2):
        """Embed ``x1`` and ``x2`` in one shared pass and classify the pair.

        ``x1`` and ``x2`` must match in every dimension except dim 0, which
        ``torch.cat`` requires.
        """
        # Split at this call's own first-input batch size. Reading an outer
        # variable (such as a training-loop ``images1``) would raise
        # NameError elsewhere, or split at the wrong index.
        batch = x1.size(0)
        # A single pass over the concatenated batch means batch-dependent
        # layers (e.g. BatchNorm in training mode) see both inputs together.
        embedded = self.emb(torch.cat((x1, x2), dim=0))
        e1, e2 = embedded[:batch], embedded[batch:]
        return self.clas(e1, e2)

Now my question is why my previous, simpler network had a problem with the gradients (it is an AlexNet with a few modifications). The code is below.

Note: I already varied the initial learning rate from 1e-1 down to 1e-4.

The network was initialized with the default parameters.

class AN(nn.Module):
    """AlexNet-style convolutional embedder returning one flat vector per image.

    Args:
        sobel: input mode.
            ``"RGB"`` feeds the 3-channel input straight into ``conv1``.
            ``"Edges"`` feeds only the 2-channel output of a fixed
            (non-trainable) grayscale + Sobel filter.
            ``"RGB+Edges"`` feeds those 2 Sobel channels concatenated in
            front of the 3 input channels (5 channels in total).
        bn: if True, apply a normalization layer after every convolution.
            Despite the name, the layers are ``nn.InstanceNorm2d``
            (per-sample statistics), not batch norm.
        init_weights: if True, re-initialize the trainable layers with
            ``_initialize_weights`` instead of PyTorch's default init.
        norm: if True, L2-normalize each output vector with ``F.normalize``.

    Raises:
        ValueError: if ``sobel`` is not one of the supported modes.
    """

    # Supported input modes -> number of channels each one feeds to conv1.
    _SOBEL_MODES = {"RGB": 3, "Edges": 2, "RGB+Edges": 5}

    def __init__(self, sobel="RGB", bn=False, init_weights=False, norm=False):
        super(AN, self).__init__()
        # Use an explicit check rather than ``assert``: asserts are stripped
        # under ``python -O``, which would let a misspelled mode silently
        # fall through to the RGB configuration.
        if sobel not in self._SOBEL_MODES:
            raise ValueError(
                f"sobel must be one of {sorted(self._SOBEL_MODES)}, got {sobel!r}"
            )
        self.norm = norm
        self.bn = bn
        self.edges = sobel

        # Registered before the conv layers, so module and parameter order
        # stays the same as before.
        if sobel in ("Edges", "RGB+Edges"):
            self.sobel = self._make_sobel()

        self.conv1 = nn.Conv2d(
            self._SOBEL_MODES[sobel], 64, kernel_size=11, padding=5, stride=4
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(64, 192, kernel_size=5, padding=2)
        self.relu2 = nn.ReLU(inplace=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU(inplace=True)
        self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1)
        self.relu4 = nn.ReLU(inplace=True)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.relu5 = nn.ReLU(inplace=True)
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        if bn:
            self.bn1 = nn.InstanceNorm2d(64)
            self.bn2 = nn.InstanceNorm2d(192)
            self.bn3 = nn.InstanceNorm2d(384)
            self.bn4 = nn.InstanceNorm2d(256)
            self.bn5 = nn.InstanceNorm2d(256)

        if init_weights:
            self._initialize_weights()

    @staticmethod
    def _make_sobel():
        """Build the frozen grayscale -> Sobel(x, y) filter (3 -> 2 channels)."""
        # Kernel for intensity changes along the width (x) axis; its
        # transpose handles the height (y) axis.
        kx = torch.tensor([[1., 0., -1.], [2., 0., -2.], [1., 0., -1.]])

        grayscale = nn.Conv2d(3, 1, kernel_size=1, stride=1, padding=0)
        with torch.no_grad():
            grayscale.weight.fill_(1.0 / 3.0)  # plain mean of the 3 channels
            grayscale.bias.zero_()

        sobel_filter = nn.Conv2d(1, 2, kernel_size=3, stride=1, padding=1)
        with torch.no_grad():
            sobel_filter.weight[0, 0].copy_(kx)
            sobel_filter.weight[1, 0].copy_(kx.t())
            sobel_filter.bias.zero_()

        sobel = nn.Sequential(grayscale, sobel_filter)
        # Fixed, hand-crafted filter: excluded from gradient updates.
        for p in sobel.parameters():
            p.requires_grad = False
        return sobel

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: image batch with 3 channels in dim 1 (both the Sobel branch
                and the RGB ``conv1`` take 3 input channels). Presumably
                (N, 3, H, W); confirm against the data pipeline.

        Returns:
            Tensor of shape (N, D): the flattened output of the last
            max-pool, L2-normalized per row when ``norm`` is True.
        """
        if self.edges == "Edges":
            x = self.sobel(x)
        elif self.edges == "RGB+Edges":
            x = torch.cat((self.sobel(x), x), dim=1)

        x = self.conv1(x)
        if self.bn:
            x = self.bn1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        if self.bn:
            x = self.bn2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        x = self.conv3(x)
        if self.bn:
            x = self.bn3(x)
        x = self.relu3(x)

        x = self.conv4(x)
        if self.bn:
            x = self.bn4(x)
        x = self.relu4(x)

        x = self.conv5(x)
        if self.bn:
            x = self.bn5(x)
        x = self.relu5(x)
        x = self.maxpool3(x)
        x = x.view(x.size(0), -1)

        if self.norm:
            x = F.normalize(x)

        return x

    def _initialize_weights(self):
        """Re-initialize the trainable layers in place.

        - Conv2d: weights ~ N(0, sqrt(2 / (k_h * k_w * out_channels))),
          bias 0.
        - BatchNorm2d / affine InstanceNorm2d: weight 1, bias 0.
        - Linear: weights ~ N(0, 0.01), bias 0.

        The frozen grayscale/Sobel filter is skipped so its hand-set kernels
        survive.
        """
        print('Initiating network weights')
        # self.modules() also yields the edge filter's convs. Re-randomizing
        # them would leave random filters in front of conv1 that can never
        # train (requires_grad is False), so they are excluded.
        frozen = set(self.sobel.modules()) if hasattr(self, "sobel") else set()
        with torch.no_grad():
            for m in self.modules():
                if m in frozen:
                    continue
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.normal_(0, math.sqrt(2. / n))
                    if m.bias is not None:
                        m.bias.zero_()
                elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                    # Norm layers built with affine=False (the InstanceNorm2d
                    # default) have weight/bias set to None.
                    if m.weight is not None:
                        m.weight.fill_(1)
                    if m.bias is not None:
                        m.bias.zero_()
                elif isinstance(m, nn.Linear):
                    m.weight.normal_(0, 0.01)
                    if m.bias is not None:
                        m.bias.zero_()

Thanks again for your reply :slight_smile: