Two-tower network stops learning when adding linear layers

I have been trying to train a relatively simple two-tower network for recommendation. I am using PyTorch and the implementation is as follows: embedding layers for users and items, an optional feed-forward net for each tower, a dot product between the user and item representations, and a sigmoid.
I am training with binary cross-entropy loss and the Adam optimizer. When I use only the embeddings, I see improvements from epoch to epoch (the loss decreases and the evaluation metrics increase). However, once I add even a single feed-forward layer, the network learns just a bit in the first epoch and then stagnates. I have tried hard-coding one linear layer with ReLU to check whether the issue is with the way I am creating the list of layers, but this did not change anything.

Has anybody else had a similar problem?

Implementation:

import torch
import torch.nn as nn


class SimpleTwoTower(nn.Module):

    def __init__(self, n_items, n_users, ln):
        super(SimpleTwoTower, self).__init__()

        # ln[0] is the embedding dim; the remaining entries are the hidden sizes of each tower
        self.ln = ln
        self.item_emb = nn.Embedding(num_embeddings=n_items, embedding_dim=self.ln[0])
        self.user_emb = nn.Embedding(num_embeddings=n_users, embedding_dim=self.ln[0])

        self.item_layers = []  # nn.ModuleList()
        self.user_layers = []  # nn.ModuleList()

        # build Linear + ReLU pairs for each consecutive pair of sizes in ln
        for i, n in enumerate(ln[0:-1]):
            m = int(ln[i + 1])
            self.item_layers.append(nn.Linear(n, m, bias=True))
            self.item_layers.append(nn.ReLU())

            self.user_layers.append(nn.Linear(n, m, bias=True))
            self.user_layers.append(nn.ReLU())

        self.item_layers = nn.Sequential(*self.item_layers)
        self.user_layers = nn.Sequential(*self.user_layers)

        self.dot = torch.matmul
        self.sigmoid = nn.Sigmoid()

    def forward(self, items, users):

        item_emb = self.item_emb(items)
        user_emb = self.user_emb(users)

        item_emb = self.item_layers(item_emb)
        user_emb = self.user_layers(user_emb)

        # score every user in the batch against every item in the batch
        dp = self.dot(user_emb, item_emb.t())
        return self.sigmoid(dp)

I assume you are using nn.BCELoss based on your description and the usage of the sigmoid activation? If so, could you remove the sigmoid and replace the criterion with nn.BCEWithLogitsLoss, as it would give you more numerical stability?
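Something along these lines (a minimal sketch; `model`, `items`, `users`, and `targets` are assumed to come from your training loop):

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()          # sigmoid + BCE fused, numerically more stable

# forward() should now return dp directly instead of self.sigmoid(dp)
logits = model(items, users)                # raw scores (logits)
loss = criterion(logits, targets.float())   # targets: same shape as logits, values in {0, 1}
loss.backward()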

Based on your usage of self.dot = torch.matmul it seems you are creating an output tensor of the shape [batch_size, batch_size]. I assume this is expected?

I was indeed using nn.BCELoss with a sigmoid activation. I tried your suggestion, i.e. removing the sigmoid and training with nn.BCEWithLogitsLoss, but I did not get any improvement. This is the loss throughout training for 15 epochs.

Do you have any idea why this may not be working?

To your second question: yes, the idea is to get a score for each user-item pair in the batch, which is why the output tensor is of size [batch_size, batch_size].
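To illustrate with a toy shape check (the sizes here are made up):

import torch

batch_size, dim = 4, 64
user_vecs = torch.randn(batch_size, dim)    # one row per user in the batch
item_vecs = torch.randn(batch_size, dim)    # one row per item in the batch

scores = torch.matmul(user_vecs, item_vecs.t())
print(scores.shape)  # torch.Size([4, 4]) -> one score for every user-item pair in the batch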

No idea yet what might be wrong. However, it looks as if the loss is pretty stable.
Could you check the gradients of all parameters after calling the .backward() operation via:

for name, param in model.named_parameters():
    print(name, param.grad.abs().sum())

and check their values?

Ok, so the results are:

999 item_emb.weight tensor(0.0596, device='cuda:0')
999 user_emb.weight tensor(0.0473, device='cuda:0')
999 item_layers.0.weight tensor(0.3596, device='cuda:0')
999 item_layers.0.bias tensor(0.0378, device='cuda:0')
999 user_layers.0.weight tensor(0.4759, device='cuda:0')
999 user_layers.0.bias tensor(0.0414, device='cuda:0')
1999 item_emb.weight tensor(0.0013, device='cuda:0')
1999 user_emb.weight tensor(0.0053, device='cuda:0')
1999 item_layers.0.weight tensor(0.0060, device='cuda:0')
1999 item_layers.0.bias tensor(0.0005, device='cuda:0')
1999 user_layers.0.weight tensor(0.1563, device='cuda:0')
1999 user_layers.0.bias tensor(0.0128, device='cuda:0')
2999 item_emb.weight tensor(0., device='cuda:0')
2999 user_emb.weight tensor(0., device='cuda:0')
2999 item_layers.0.weight tensor(0., device='cuda:0')
2999 item_layers.0.bias tensor(0., device='cuda:0')
2999 user_layers.0.weight tensor(0., device='cuda:0')
2999 user_layers.0.bias tensor(0., device='cuda:0')
3999 item_emb.weight tensor(0., device='cuda:0')
3999 user_emb.weight tensor(0., device='cuda:0')
3999 item_layers.0.weight tensor(0., device='cuda:0')
3999 item_layers.0.bias tensor(0., device='cuda:0')
3999 user_layers.0.weight tensor(0., device='cuda:0')
3999 user_layers.0.bias tensor(0., device='cuda:0')

The number in front of each row is the update step. I guess this (the gradients going to zero) is what you thought was happening.

I tried two approaches: changing nn.ReLU to nn.LeakyReLU, and separately, adding a batch-normalization layer after every activation. Both worked, and I think I will stick with the LeakyReLU approach.
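In case it helps, the LeakyReLU version of the layer-building loop looks roughly like this (pulled out into a hypothetical build_tower helper just for illustration):

import torch.nn as nn

def build_tower(ln):
    # same loop as in __init__, but with LeakyReLU so negative inputs keep a small gradient
    layers = []
    for i, n in enumerate(ln[:-1]):
        m = int(ln[i + 1])
        layers.append(nn.Linear(n, m, bias=True))
        layers.append(nn.LeakyReLU())
    return nn.Sequential(*layers)

# e.g. build_tower([2**10, 2**8, 2**6]) gives Linear(1024, 256) -> LeakyReLU -> Linear(256, 64) -> LeakyReLU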

I assume this means your model trains now? If so, great to hear!

I have a similar issue, but it still doesn't work with LeakyReLU or batch normalization. Furthermore, the user tower predicts the same result for different users.


I don't think the problem is with the ReLU layers as such.

If we execute the code below:


import torch
import torch.nn as nn

class SimpleTwoTower(nn.Module):
    
    def __init__(self, n_items, n_users, ln):
        super(SimpleTwoTower, self).__init__()
        
        self.ln = ln
        self.item_emb = nn.Embedding(num_embeddings=n_items, embedding_dim=self.ln[0])
        self.user_emb = nn.Embedding(num_embeddings=n_users, embedding_dim=self.ln[0])
       
        
        self.item_layers = [] #nn.ModuleList()
        self.user_layers = [] #nn.ModuleList()
        
        for i, n in enumerate(ln[0:-1]):
            m = int(ln[i+1])
            self.item_layers.append(nn.Linear(n, m, bias=True))
            self.item_layers.append(nn.ReLU())
            
            self.user_layers.append(nn.Linear(n, m, bias=True))
            self.user_layers.append(nn.ReLU())
            
            
        self.item_layers = nn.Sequential(*self.item_layers)
        self.user_layers = nn.Sequential(*self.user_layers)
        
        self.dot = torch.matmul
        self.sigmoid = nn.Sigmoid()

    def forward(self, items, users):
        
        item_emb = self.item_emb(items)
        user_emb = self.user_emb(users)
        
        item_emb = self.item_layers(item_emb)
        user_emb = self.user_layers(user_emb)

        dp = self.dot(user_emb, item_emb.t())
        return self.sigmoid(dp)

# let's create our model here using the snippet as-is
SimpleTwoTower(10000, 1000,  [2**10, 2**8, 2**6])

"""
# output will be something like this

SimpleTwoTower(
  (item_emb): Embedding(10000, 1024)
  (user_emb): Embedding(1000, 1024)
  (item_layers): Sequential(
    (0): Linear(in_features=1024, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=64, bias=True)
    (3): ReLU()  #### <<<<<<< isn't this an issue?
  )
  (user_layers): Sequential(
    (0): Linear(in_features=1024, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=64, bias=True)
    (3): ReLU()  #### <<<<<<< isn't this an issue?
  )
  (sigmoid): Sigmoid()
)
"""

So, do we really want to have that final ReLU as well? From what I know about the two-tower model, we use the raw output of the last linear layer as the feature vector (i.e. the latent vector).
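If that trailing ReLU is indeed unwanted, one way to drop it would be to apply the activation after every linear layer except the last one (a sketch only, using a hypothetical helper and assuming the same ln layout as above):

import torch.nn as nn

def build_tower_no_final_act(ln):
    # apply ReLU after every linear layer except the last,
    # so the tower outputs the raw feature vector of the final Linear
    layers = []
    num_layers = len(ln) - 1
    for i, n in enumerate(ln[:-1]):
        m = int(ln[i + 1])
        layers.append(nn.Linear(n, m, bias=True))
        if i < num_layers - 1:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)

# build_tower_no_final_act([2**10, 2**8, 2**6]):
# Linear(1024, 256) -> ReLU -> Linear(256, 64)   (no trailing ReLU)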

Also, do we want to use embedding layers as well?

cc @ptrblck @Mimi_Lazarova

Please correct my understanding!

Be Safe,
Best,
Aditya.

I have a similar issue: after training, the two towers give very random results. How did you fix this issue?