Why is my model getting stuck?

Hi,

I am using this as the loss criterion:

import numpy as np
import torch


class TripletLoss(torch.nn.Module):
    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        distance_positive = (anchor - positive).pow(2).sum(1)
        distance_negative = (anchor - negative).pow(2).sum(1)
        losses = torch.nn.functional.relu(distance_positive - distance_negative + self.margin)
        return losses.mean()

And this as a network:

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = torch.nn.Conv2d(1, 16, 5)
        self.pool1 = torch.nn.MaxPool2d(2)
        self.conv2 = torch.nn.Conv2d(16, 64, 3)
        self.pool2 = torch.nn.MaxPool2d(2)
        self.conv3 = torch.nn.Conv2d(64, 128, 3)
        self.pool3 = torch.nn.MaxPool2d(2)
        self.linear1 = torch.nn.Linear(1152, 512)
        self.linear2 = torch.nn.Linear(512, 64)

    def forward(self, x):
        x = torch.nn.functional.relu(self.conv1(x))
        x = self.pool1(x)
        x = torch.nn.functional.relu(self.conv2(x))
        x = self.pool2(x)
        x = torch.nn.functional.relu(self.conv3(x))
        x = self.pool3(x)
        x = x.view(x.shape[0], -1)
        x = torch.nn.functional.relu(self.linear1(x))
        x = torch.nn.functional.relu(self.linear2(x))

        return x
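The 1152 input features of linear1 follow from the 40x40 input size used below: 40 → 36 (5x5 conv) → 18 (pool) → 16 (3x3 conv) → 8 (pool) → 6 (3x3 conv) → 3 (pool), and 128 · 3 · 3 = 1152. A quick sketch that checks this by pushing a dummy input through the same convolution/pooling stack (rebuilt here as an nn.Sequential purely for illustration):

```python
import torch

# Convolution/pooling stack from the question, without the linear layers.
features = torch.nn.Sequential(
    torch.nn.Conv2d(1, 16, 5), torch.nn.ReLU(), torch.nn.MaxPool2d(2),    # 40 -> 36 -> 18
    torch.nn.Conv2d(16, 64, 3), torch.nn.ReLU(), torch.nn.MaxPool2d(2),   # 18 -> 16 -> 8
    torch.nn.Conv2d(64, 128, 3), torch.nn.ReLU(), torch.nn.MaxPool2d(2),  # 8 -> 6 -> 3
)

dummy = torch.zeros(1, 1, 40, 40)  # one single-channel 40x40 image
out = features(dummy)
print(out.shape)                   # torch.Size([1, 128, 3, 3])
print(out.flatten(1).shape[1])     # 1152, the in_features of linear1
```

Changing the input resolution would change this number, so linear1 would need a different in_features.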

I train the network with random data. I also print the value of the loss returned by the criterion.

if __name__ == '__main__':
    model = Net()
    criterion = TripletLoss(0.5)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    model.train()

    for step in range(10000):  # loop counter renamed so it is not shadowed by the input x below
        optimizer.zero_grad()

        x = torch.tensor(np.random.normal(size=[1, 1, 40, 40])).float()
        y = torch.tensor(np.random.normal(size=[1, 1, 40, 40])).float()
        z = torch.tensor(np.random.normal(size=[1, 1, 40, 40])).float()
        anchor = model(x)
        positive = model(y)
        negative = model(z)

        loss = criterion(anchor, positive, negative)
        print(loss)

        loss.backward()
        optimizer.step()

Since the input data is random, I expect the network to learn nothing (i.e., to produce random loss values).
The problem that I have is that the loss quickly converges to 0.5. The weights of the different layers of the network also converge to stable values.

...
tensor(0.5000, grad_fn=<MeanBackward1>)
tensor(0.5000, grad_fn=<MeanBackward1>)
tensor(0.5000, grad_fn=<MeanBackward1>)
...

The same happens when using real (non-random) data.
Why is this happening?

Could you check if the outputs converge towards a constant value, e.g. a zero tensor?
In that case all distances would be zero and only the margin would push your loss to 0.5.
Could you try to lower the learning rate a bit and check if the loss decreases further?
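A minimal check of this explanation, reusing the TripletLoss definition from the question: if the network maps every input to the same vector (all zeros, or any other constant), both distances are exactly zero, each per-sample loss is relu(0 - 0 + margin) = margin, and the mean is the margin itself.

```python
import torch


class TripletLoss(torch.nn.Module):
    # Same squared-L2 triplet loss as in the question.
    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        distance_positive = (anchor - positive).pow(2).sum(1)
        distance_negative = (anchor - negative).pow(2).sum(1)
        losses = torch.nn.functional.relu(distance_positive - distance_negative + self.margin)
        return losses.mean()


criterion = TripletLoss(0.5)

# Simulate a collapsed network: every embedding in a batch of 4 is the same 64-dim zero vector.
collapsed = torch.zeros(4, 64)
loss_zero = criterion(collapsed, collapsed, collapsed)

# Any constant embedding gives the same result, not only the zero vector.
constant = torch.full((4, 64), 3.0)
loss_constant = criterion(constant, constant, constant)

print(loss_zero, loss_constant)  # both equal the margin, 0.5
```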

If you are considering the L2 loss between two vectors, these lines seem to be wrong. Can you recheck?

Hi, ptrblck.

Yes, they converge to zero. After a few iterations, the following code prints:

print(anchor)
print(loss)

# tensor([[0., 0., 0., ..... 0., 0.]], grad_fn=<ReluBackward0>)
# tensor(0.5000, grad_fn=<MeanBackward1>)

With a learning rate of 1e-3 the same happens; it just takes more iterations for the values to converge. With a learning rate of 1e-4 I haven’t run the code long enough for it to converge, but with every iteration there are fewer non-zero values in the output vectors (i.e., it’s converging towards a zero vector).

These results are with random input data, and they kind of make sense to me. Since the input data is random, the network won’t be able to do any better than 0.5 (the margin), and the best way to converge to that loss value is to always return output vectors of zeros.
Is my reasoning right?
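One way to make that reasoning precise, as a sketch under the assumption that x, y and z are drawn independently from the same distribution (as in the random-data loop), with f the network and m the margin: the pairs (x, y) and (x, z) have the same joint distribution, so the two distances have the same mean, and because max(0, ·) is convex, Jensen’s inequality bounds the expected loss from below by the margin:

```latex
$$ d_p = \lVert f(x) - f(y) \rVert_2^2, \qquad d_n = \lVert f(x) - f(z) \rVert_2^2, \qquad \mathbb{E}[d_p - d_n] = 0 $$

$$ \mathbb{E}\left[\max(0,\, d_p - d_n + m)\right] \;\ge\; \max\left(0,\, \mathbb{E}[d_p - d_n] + m\right) = m $$
```

Equality holds whenever d_p - d_n + m never becomes negative; a constant output (d_p = d_n = 0) is the simplest way to reach it, which is consistent with the zero vectors printed above.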

On the other hand, when using real data, the following happens:

The plot above shows the mean error (over 100 values) for 10000 iterations with a learning rate of 1e-4. That looks great, but what bothers me is that the feature vectors returned by the network are mostly zeros:

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.1780, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2754,
         0.0534, 0.2608, 0.0000, 0.0000, 0.2360, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0242, 0.0000, 0.0000, 0.3901,
         0.0000, 0.0557, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000]], grad_fn=<ReluBackward0>)

As I understand it, this is because of the ReLU activations I am using after every layer of the network. Could I get better results by removing some of the ReLU layers?
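A small illustration of the mechanism behind that intuition: ReLU outputs 0 and passes zero gradient for negative inputs, so once an output unit’s pre-activation is negative, that entry of the embedding is 0 and the loss sends no gradient back through it. If all 64 outputs end up in that state, the embedding is the zero vector and the loss stays at the margin. The tensor values below are arbitrary examples.

```python
import torch
import torch.nn.functional as F

# Pre-activations of four output units: two negative, two positive (arbitrary example values).
pre_activation = torch.tensor([-1.5, -0.2, 0.3, 2.0], requires_grad=True)
output = F.relu(pre_activation)
print(output)  # the two negative entries are clamped to 0

# Backpropagate a loss that depends on every output.
output.sum().backward()
print(pre_activation.grad)  # tensor([0., 0., 1., 1.]): no gradient reaches the clamped units
```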

Thanks.

Yeah, I am considering the L2 distance, but I don’t see anything wrong with the code. Can you elaborate?

Try to remove the last relu activation in your model, which should make the output more diverse.
Also, I’m not sure if it’s a good idea to re-sample the data in the training loop, as your model won’t be able to learn these random samples.

I could pretty easily overfit your model after applying these two changes.
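A sketch of the first suggested change, assuming the single-channel 40x40 inputs from the question: the architecture is otherwise unchanged, but the output of linear2 is returned directly. With the final ReLU in place, every negative pre-activation becomes exactly 0, which is why the printed feature vectors are mostly zeros; without it, the 64-dim embedding can use negative values as well.

```python
import torch
import torch.nn.functional as F


class Net(torch.nn.Module):
    # Same architecture as in the question; only the ReLU after linear2 is removed.
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 16, 5)
        self.pool1 = torch.nn.MaxPool2d(2)
        self.conv2 = torch.nn.Conv2d(16, 64, 3)
        self.pool2 = torch.nn.MaxPool2d(2)
        self.conv3 = torch.nn.Conv2d(64, 128, 3)
        self.pool3 = torch.nn.MaxPool2d(2)
        self.linear1 = torch.nn.Linear(1152, 512)  # 128 channels * 3 * 3 for 40x40 inputs
        self.linear2 = torch.nn.Linear(512, 64)

    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        x = x.view(x.shape[0], -1)
        x = F.relu(self.linear1(x))
        return self.linear2(x)  # no ReLU here, so the embedding is not clamped at zero


torch.manual_seed(0)
model = Net()
embedding = model(torch.randn(1, 1, 40, 40))
print(embedding.shape)  # torch.Size([1, 64])
print((embedding < 0).sum().item(), "of 64 entries are negative")
```

The hidden ReLUs are kept; only the layer that produces the embedding compared by the triplet loss is left linear.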

That’s the first thing I did right after posting, and from the limited testing I’ve done so far, I’d say that the network is learning faster.

I am afraid I don’t understand what you mean by re-sampling the data. My current training loop (using real data, not the random-data example given above) looks like this:

for index, triplet in enumerate(data):
    optimizer.zero_grad()

    anchor = model(torch.tensor([[triplet[0]]]).float())
    positive = model(torch.tensor([[triplet[0]]]).float())
    negative = model(torch.tensor([[triplet[0]]]).float())

    loss = criterion(anchor, positive, negative)
    loss.backward()
    optimizer.step()

    # print error and others every 100 iters.

Thanks.

Oh yes, I misread the code. It seems correct. Please ignore my messages.
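A quick numerical check of that conclusion: (anchor - positive).pow(2).sum(1) is the squared Euclidean (L2) distance between each pair of rows. Note that this is the squared-distance variant of the triplet loss; PyTorch’s built-in torch.nn.TripletMarginLoss uses the non-squared p-norm distance by default, so its loss values differ from this implementation. The batch and embedding sizes below are arbitrary.

```python
import torch

torch.manual_seed(0)
anchor = torch.randn(8, 64)    # batch of 8 embeddings, 64-dim (illustrative sizes)
positive = torch.randn(8, 64)

# Expression from the question: sum of squared differences along the feature dimension.
manual = (anchor - positive).pow(2).sum(1)

# Squared L2 norm of the difference, computed with torch.linalg.vector_norm.
reference = torch.linalg.vector_norm(anchor - positive, ord=2, dim=1).pow(2)

print(torch.allclose(manual, reference, rtol=1e-5, atol=1e-5))  # True
```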

Is it a typo or are you using the same data for the anchor, positive, and negative sample?
Could you try to use only a few data samples (e.g. 10 samples) and overfit your model on this data?
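A sketch of that overfitting check, with several assumptions: random tensors stand in for 10 real triplets, the network from the question is used with the final ReLU removed, and the whole tiny set is trained as one batch with Adam at lr=1e-3 for 300 steps (arbitrary choices). Unlike the random-data loop in the question, the data is sampled once, outside the loop, so the model sees the same triplets at every step and can memorize them.

```python
import torch
import torch.nn.functional as F


class Net(torch.nn.Module):
    # Architecture from the question, with the final ReLU removed.
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 16, 5)
        self.conv2 = torch.nn.Conv2d(16, 64, 3)
        self.conv3 = torch.nn.Conv2d(64, 128, 3)
        self.linear1 = torch.nn.Linear(1152, 512)
        self.linear2 = torch.nn.Linear(512, 64)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = F.max_pool2d(F.relu(self.conv3(x)), 2)
        x = F.relu(self.linear1(x.flatten(1)))
        return self.linear2(x)


def triplet_loss(anchor, positive, negative, margin=0.5):
    # Same squared-L2 triplet loss as TripletLoss in the question.
    d_pos = (anchor - positive).pow(2).sum(1)
    d_neg = (anchor - negative).pow(2).sum(1)
    return F.relu(d_pos - d_neg + margin).mean()


torch.manual_seed(0)
# 10 fixed triplets, sampled ONCE before training (random stand-ins for real data).
anchors, positives, negatives = torch.randn(3, 10, 1, 40, 40).unbind(0)

model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

losses = []
for step in range(300):
    optimizer.zero_grad()
    loss = triplet_loss(model(anchors), model(positives), model(negatives))
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

If the loss on such a tiny fixed set does not approach zero, that points at the model or the hyperparameters rather than at the data.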

Yeah, it’s a typo in the posted example code. In my actual code I have a Python one-liner:

anchor, positive, negative = [model(torch.tensor([[x]]).float()) for x in triplet]
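For reference, a sketch of what the one-liner does and an equivalent batched alternative, assuming each triplet is a sequence of three 40x40 NumPy arrays. The model below is a hypothetical stand-in (a single linear layer). For a network like the one in the question, which has no batch-dependent layers such as BatchNorm, one forward pass over a batch of three gives the same embeddings as three separate passes, up to floating-point rounding.

```python
import numpy as np
import torch

np.random.seed(0)
torch.manual_seed(0)

# Hypothetical stand-ins: a tiny embedding model and one triplet of 40x40 arrays.
model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(40 * 40, 64))
triplet = [np.random.normal(size=(40, 40)) for _ in range(3)]

# Per-sample version, like the one-liner: three forward passes on shape (1, 1, 40, 40).
separate = [model(torch.from_numpy(x).float().view(1, 1, 40, 40)) for x in triplet]

# Batched version: stack into shape (3, 1, 40, 40), run one forward pass, split the rows.
batch = torch.from_numpy(np.stack(triplet)).float().unsqueeze(1)
anchor, positive, negative = model(batch).split(1, dim=0)
```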

The above example corrected would be:

for index, triplet in enumerate(data):
    optimizer.zero_grad()

    anchor = model(torch.tensor([[triplet[0]]]).float())
    positive = model(torch.tensor([[triplet[1]]]).float())
    negative = model(torch.tensor([[triplet[2]]]).float())

    loss = criterion(anchor, positive, negative)
    loss.backward()
    optimizer.step()

    # print error and others every 100 iters.

My intent was to post clearer (easier to read) code; I’ve clearly failed miserably, since it caused more confusion.

No worries! I was just worried that this typo might make the training impossible.
Were you successful testing a small data sample?

I believe I have been, but I am so new to the field of deep learning that I am never sure whether I am making progress or not.

Here is a plot of the loss evolution over the training set (8 different classes).

And here are the features of four different classes (not from the training set).

I know these features are crap, but some separation between the different classes can be seen.

Thanks.

For just 10 samples, the loss should decrease basically to zero. Your model still seems to have some trouble learning the data. As a next step, I would play around with some hyperparameters (e.g. the learning rate) and the model architecture, and force the model to learn this tiny dataset perfectly before digging any further.