Hi, I trained a model whose training loss decreases nicely at first, but after a few epochs it increases and then decreases again. Is this normal, or is there something wrong with my network? Below you can find the code and the loss values. Any help is really appreciated.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

# HPS is a hyperparameter dict defined elsewhere (mlp_hidden_size,
# projection_size, batch_size, base_target_ema, num_epochs, optimizer_config).

class MLP(nn.Module):
    def __init__(self, in_dim):
        super(MLP, self).__init__()
        self.head = nn.Sequential(
            nn.Linear(in_dim, HPS['mlp_hidden_size']),
            nn.BatchNorm1d(HPS['mlp_hidden_size']),
            nn.ReLU(),
            nn.Linear(HPS['mlp_hidden_size'], HPS['projection_size']))

    def forward(self, x):
        return self.head(x)
class network(nn.Module):
    def __init__(self, net):
        super(network, self).__init__()
        self.net = net
        resnet = models.resnet18(pretrained=False, progress=True)
        ## Here we get representations from the avg_pooling layer
        self.encoder = torch.nn.Sequential(*list(resnet.children())[:-1])
        self.projection = MLP(in_dim=resnet.fc.in_features)
        self.prediction = MLP(in_dim=HPS['projection_size'])

    def forward(self, x):
        embedding = self.encoder(x)
        embedding = embedding.view(embedding.size()[0], -1)
        project = self.projection(embedding)
        if self.net == 'target':
            return project
        predict = self.prediction(project)
        return predict
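As a quick sanity check on the wiring, a dummy forward pass (a minimal sketch; the batch and image sizes are illustrative, and HPS must already be defined) should give one row per image with HPS['projection_size'] columns:

# Shape check with a fake CIFAR-10-sized batch (illustrative sizes)
net = network('online')
dummy = torch.randn(4, 3, 32, 32)
print(net(dummy).shape)  # expected: torch.Size([4, HPS['projection_size']])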
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.491, 0.482, 0.447], [0.247, 0.243, 0.261])])
class pair_aug(object):
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, img):
        if self.transform:
            img1 = self.transform(img)
            img2 = self.transform(img)
        return img1, img2
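For what it's worth, pair_aug just applies the same transform twice, so with a deterministic pipeline like ToTensor + Normalize the two views come out identical; BYOL relies on random augmentations to make the views differ. A small usage sketch:

# Two views of the same image (dummy PIL image, CIFAR-sized)
from PIL import Image
aug = pair_aug(transform)
img = Image.new('RGB', (32, 32))
view1, view2 = aug(img)
print(view1.shape, view2.shape)  # torch.Size([3, 32, 32]) each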
def main(num_epochs):
    dataset = datasets.CIFAR10('./data', train=True, transform=pair_aug(transform), download=True)
    trainloader = DataLoader(dataset, batch_size=HPS['batch_size'], num_workers=4, shuffle=True)
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    online_network = network('online').to(device)
    target_network = network('target').to(device)
    optimizer = optim.Adam(online_network.parameters(), lr=HPS['optimizer_config']['lr'])
    # Initializing the target network as a copy of the online network
    for online_params, target_params in zip(online_network.parameters(), target_network.parameters()):
        target_params.data.copy_(online_params.data)  # initialize
        target_params.requires_grad = False
    for epoch in range(num_epochs):
        total_loss = 0
        for (img1, img2), _ in trainloader:
            img1 = img1.to(device)
            img2 = img2.to(device)
            loss = loss_fn(online_network, target_network, img1, img2)
            total_loss += loss.item()
            ## Update the online network
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            ## Update the target network as an EMA of the online weights,
            ## with decay rate HPS['base_target_ema'] annealed by update_tau
            tau = update_tau(epoch)
            with torch.no_grad():
                for online_params, target_params in zip(online_network.parameters(), target_network.parameters()):
                    target_params.data = tau * target_params.data + (1 - tau) * online_params.data
        print(f'Epoch: {epoch} Epoch_Loss: {total_loss/len(trainloader):0.4f}')
def loss_fn(online_network, target_network, img1, img2):
    onl_pred1 = online_network(img1)
    onl_pred2 = online_network(img2)
    with torch.no_grad():
        tar_proj1 = target_network(img1)
        tar_proj2 = target_network(img2)

    def reg_loss(x, y):
        x = F.normalize(x, dim=1)
        y = F.normalize(y, dim=1)
        return 2 - 2 * (x * y).sum(dim=-1)

    ## Symmetrized loss: each view's prediction regresses the other view's target projection
    loss = reg_loss(onl_pred1, tar_proj2)
    loss += reg_loss(onl_pred2, tar_proj1)
    ### Note that the reference implementation also returns loss.mean();
    ### here we compute the mean loss per batch
    return loss.mean()
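The regression loss above is the BYOL objective: for unit vectors, ||x − y||² = 2 − 2·x·y, so 2 − 2·cosine-similarity is exactly the squared L2 distance between the normalized prediction and target. A quick numerical check of that identity (a standalone sketch with random tensors):

# Sanity check: 2 - 2*cos_sim equals the squared L2 distance of unit vectors
x, y = torch.randn(8, 128), torch.randn(8, 128)
xn, yn = F.normalize(x, dim=1), F.normalize(y, dim=1)
lhs = 2 - 2 * (xn * yn).sum(dim=-1)          # the reg_loss above
rhs = (xn - yn).pow(2).sum(dim=-1)           # squared L2 distance
print(torch.allclose(lhs, rhs, atol=1e-5))   # True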
def update_tau(epoch):
    ## Cosine annealing of the EMA rate, as in the BYOL paper:
    ## tau = 1 - (1 - tau_base) * (cos(pi*k/K) + 1) / 2,
    ## which starts at tau_base and increases to 1
    tau = 1 - (1 - HPS['base_target_ema']) * (math.cos(math.pi * epoch / HPS['num_epochs']) + 1) / 2
    return tau
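With this schedule tau starts at base_target_ema and anneals to 1 over training, so the target network freezes toward the end. A quick print with illustrative values (demo numbers, not the actual HPS) verifies it stays in [base_target_ema, 1]:

# Illustrative check of the cosine EMA schedule (demo values, not the real HPS)
demo = {'base_target_ema': 0.996, 'num_epochs': 10}
for k in (0, 5, 10):
    t = 1 - (1 - demo['base_target_ema']) * (math.cos(math.pi * k / demo['num_epochs']) + 1) / 2
    print(k, round(t, 5))  # 0 -> 0.996, 5 -> 0.998, 10 -> 1.0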
## Training loss
Epoch: 0 Epoch_Loss: 0.1489
Epoch: 1 Epoch_Loss: 0.0108
Epoch: 2 Epoch_Loss: 0.0100
Epoch: 3 Epoch_Loss: 0.0087
Epoch: 4 Epoch_Loss: 0.0054
Epoch: 5 Epoch_Loss: 0.0092
Epoch: 6 Epoch_Loss: 0.0855
Epoch: 7 Epoch_Loss: 0.2342
Epoch: 8 Epoch_Loss: 0.0304
Epoch: 9 Epoch_Loss: 0.0142