Loss is increased after some epochs

Hi, I trained a model that training loss is decreasing nicely. But after some epochs, it increases and decreases again. Is this normal or there is something wrong in my network. Below you can find the code and loss values. Any help is really appreciated.

class MLP(nn.Module): 
    def __init__(self,in_dim):
        super(MLP,self).__init__()
        self.head = nn.Sequential(nn.Linear(in_dim,HPS['mlp_hidden_size']),
                                 nn.BatchNorm1d(HPS['mlp_hidden_size']),
                                 nn.ReLU(),
                                 nn.Linear(HPS['mlp_hidden_size'],HPS['projection_size']))
        
    def forward(self,x):
        x= self.head(x)
        return(x)
    
class network(nn.Module):
    
    def __init__(self,net):  
        super(network,self).__init__()
        self.net = net
        resnet = models.resnet18(pretrained=False, progress=True)
        
        ## Here we get representations from avg_pooling layer
        self.encoder = torch.nn.Sequential(*list(resnet.children())[:-1])
        self.projection = MLP(in_dim= resnet.fc.in_features) 
        self.prediction = MLP(in_dim= HPS['projection_size'])
        
    def forward(self,x):
        
        embedding = self.encoder(x)
        #print(embedding.size())
        
        embedding = embedding.view(embedding.size()[0],-1)
        project = self.projection(embedding)
        
        if self.net=='target':
            return(project)
        
        predict = self.prediction(project)
        return(predict)

transform=transforms.Compose([transforms.ToTensor(),
                                   transforms.Normalize([0.491, 0.482, 0.447], [0.247, 0.243, 0.261])])

class pair_aug(object):
    def __init__(self,transform):
        self.transform = transform
        
    def __call__(self,img):
        if self.transform:
            img1= self.transform(img)
            img2= self.transform(img)
        return(img1,img2)
    

def main(num_epochs):
    
    dataset = datasets.CIFAR10('./data', train=True, transform=pair_aug(transform), download=True)
    
    trainloader = DataLoader(dataset,batch_size=HPS['batch_size'],num_workers=4,shuffle=True)
    
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    
    online_network = network('online').to(device)
    
    target_network = network('target').to(device)
    
    optimizer = optim.Adam(online_network.parameters(),HPS['optimizer_config']['lr'])
    
    # initilizing target_network
    for online_params, target_params in zip(online_network.parameters(), target_network.parameters()):
            target_params.data.copy_(online_params.data)  # initialize
            target_params.requires_grad = False
    
    
    
    for epoch in range(num_epochs):
        
        total_loss = 0
        
        for (img1,img2),_ in trainloader:
            
            img1= img1.to(device)
            
            img2= img2.to(device)
            
            loss = loss_fn(online_network,target_network,img1,img2)
            
            total_loss+=loss.item()
            
            ## Update online network
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
    
            ##Update target network
            ## HPS['base_target_ema']
            tau = update_tau(epoch)
    
            for online_params,target_params in zip(online_network.parameters(),target_network.parameters()):
            
                target_params.data = tau*target_params +(1-tau)*online_params
                
        print(f'Epoch: {epoch} Epoch_Loss: {total_loss/len(trainloader):0.4f}')
            
                   
            
def loss_fn(online_network,target_network,img1,img2):
    
    onl_pred1 = online_network(img1)
    onl_pred2 = online_network(img2)
    
    with torch.no_grad():
        tar_proj1 = target_network(img1)
        tar_proj2 = target_network(img2)
    
    def reg_loss(x,y):
        x = F.normalize(x, dim=1)
        y = F.normalize(y, dim=1)
        return(2-2 * (x * y).sum(dim=-1))
    
    loss = reg_loss(onl_pred1,tar_proj2)
    
    loss+= reg_loss(onl_pred2,tar_proj1)
    
    ### Note that in implimentation they return loss.mean()
    ### Here we compute loss per batch
    return(loss.mean())
        
        
def update_tau(epoch):
    
    tau = 1-(1-HPS['base_target_ema'])*(math.cos(math.pi*epoch/HPS['num_epochs']))/2
    
    return(tau)       

## Training loss
Epoch: 0 Epoch_Loss: 0.1489
Epoch: 1 Epoch_Loss: 0.0108
Epoch: 2 Epoch_Loss: 0.0100
Epoch: 3 Epoch_Loss: 0.0087
Epoch: 4 Epoch_Loss: 0.0054
Epoch: 5 Epoch_Loss: 0.0092
Epoch: 6 Epoch_Loss: 0.0855
Epoch: 7 Epoch_Loss: 0.2342
Epoch: 8 Epoch_Loss: 0.0304
Epoch: 9 Epoch_Loss: 0.0142

It is not unusual to see a slight increase in training loss once the model converges. However, something looks off here. I would suggest printing accuracy as well to see how the correlation looks, as well as lowering the learning rate.

Thank you for your answer. Here I can’t compute accuracy, since it is an unsupervised learning algorithm and there is no label here. The way that loss is computed is that for every image we compute two transformations of the image and then the loss is the euclidean distance between these transformations that tries to minimize dissimilarity between these images and in this way the network learns. That’s why I do not have any other metric to judge the behavior of the algorithm except loss.