Hi,
I am trying out the network below, but I am running into a problem: both the training and validation loss go to zero after some batches. This happens even when the margin in the triplet loss is set to a high value (like 10000).
It would be very helpful if someone could suggest what might be wrong in the approach I have chosen.
import torch
from torch.autograd import Variable
import torch.nn as nn
import torchvision.models as models
import os
# Pin this process to GPU 3 (using PCI bus ordering for stable device numbering).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Number of positive / negative images supplied per anchor.
num_inputs = 2

# Pretrained ResNet-50 with its final fc classifier removed; the remaining
# trunk ends at the global average pool, producing (N, 2048, 1, 1) features.
resnet = models.resnet50(pretrained=True)
resnet_model = nn.Sequential(*list(resnet.children())[:-1])
class Net(nn.Module):
    """Triplet embedding network.

    A shared ResNet-50 trunk embeds the anchor image and the lists of
    positive / negative images to 512-d vectors; each list of `num_inputs`
    embeddings is collapsed to a single 512-d vector by a 1x1 Conv1d, and
    the pairwise distances anchor->positive and anchor->negative are
    returned alongside the three embeddings.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Shared backbone (fc layer removed), defined at module level.
        self.resnet = resnet_model
        self.fc1 = nn.Linear(2048, 512, bias=True)
        self.bn = nn.BatchNorm1d(512)
        # 1x1 Conv1d mixes the `num_inputs` stacked embeddings into one.
        self.conv1 = nn.Conv1d(in_channels=num_inputs, out_channels=1, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm1d(512)
        self.conv2 = nn.Conv1d(in_channels=num_inputs, out_channels=1, kernel_size=1, stride=1)
        self.bn2 = nn.BatchNorm1d(512)
        self.triplet = TripletDist()

    def forward_once(self, x):
        """Embed one image batch (N, 3, 224, 224) -> batch-normed (N, 512)."""
        output = self.resnet(x)
        N, C, H, W = output.size()
        output = output.view(N, C * H * W)
        output = self.fc1(output)
        output = self.bn(output)
        return output

    def forward(self, in_1, in_n_2, in_n_3):
        '''
        :param in_1: First input of size batch_size x 3 x 224 x 224
        :param in_n_2: A list of inputs similar to in_1 ; each of size batch_size x 3 x 224 x 224
        :param in_n_3: A list of inputs very different from in_1; each of size batch_size x 3 x 224 x 224
        :return: (distance vector, anchor emb, positive emb, negative emb)
        '''
        output_1 = self.forward_once(in_1)
        output_2_list = [self.forward_once(in_n_2[j]) for j in range(num_inputs)]
        output_3_list = [self.forward_once(in_n_3[j]) for j in range(num_inputs)]

        # (N, num_inputs, 512) -> 1x1 conv -> (N, 1, 512) -> (N, 512).
        # squeeze(1) instead of squeeze() so a batch of size 1 keeps its
        # batch dimension (a bare squeeze() would also drop dim 0).
        output_bn_n_2 = torch.stack(output_2_list, 1)
        output_n_2 = self.conv1(output_bn_n_2).squeeze(1)
        output_n_2 = self.bn1(output_n_2)

        output_bn_n_3 = torch.stack(output_3_list, 1)
        output_n_3 = self.conv2(output_bn_n_3).squeeze(1)
        # BUG FIX: the original normalized the negative branch with self.bn1,
        # silently sharing BatchNorm statistics with the positive branch and
        # leaving self.bn2 unused; use the dedicated bn2 here.
        output_n_3 = self.bn2(output_n_3)

        model_output = self.triplet(output_1, output_n_2, output_n_3)
        return model_output, output_1, output_n_2, output_n_3
class TripletDist(nn.Module):
    """Compute the 2-norm distances anchor->positive and anchor->negative
    and return them concatenated as one 1-D vector of length 2N."""

    def __init__(self):
        super(TripletDist, self).__init__()
        self.dist = nn.PairwiseDistance()

    def forward(self, *x):
        """
        :param x: (anchor, positive, negative) embedding batches, each (N, D)
        :return: 1-D tensor [d(anchor, pos); d(anchor, neg)] of length 2N
        """
        pair1_2norm_dist = self.dist(x[0], x[1])
        pair2_2norm_dist = self.dist(x[0], x[2])
        output = torch.cat((pair1_2norm_dist, pair2_2norm_dist), 0)
        # BUG FIX: nn.PairwiseDistance returns a 1-D (N,) tensor in
        # PyTorch >= 0.4, so the original `N, C = output.size()` unpack
        # raises ValueError. view(-1) flattens correctly for both the
        # legacy (N, 1) shape and the current (N,) shape.
        return output.view(-1)
# --- Training driver (smoke test on random data) ---
# Modernized for PyTorch >= 0.4 / Python >= 3.7:
#  * `.cuda(async=True)` is a SyntaxError on Python 3.7+ (`async` is a keyword);
#  * Variable / volatile are deprecated — plain tensors + torch.no_grad();
#  * `loss.data[0]` was removed — use `loss.item()`;
#  * model is moved to the device once, not inside the loop.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Net().to(device)

# Fixed random validation batch; no gradients are ever needed for it.
with torch.no_grad():
    val_1 = torch.randn(4, 3, 224, 224, device=device)
    val_n_2 = [torch.randn(4, 3, 224, 224, device=device) for _ in range(num_inputs)]
    val_n_3 = [torch.randn(4, 3, 224, 224, device=device) for _ in range(num_inputs)]

criterion = nn.TripletMarginLoss(margin=3000, p=2)
# NOTE(review): lr=0.1 is very aggressive for Adam when fine-tuning a
# pretrained backbone (1e-3 .. 1e-4 is typical) and can collapse all
# embeddings — a likely cause of the loss going to zero.
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)

for i in range(1000):
    model.train()
    # Fresh random training batch each iteration.
    input_1 = torch.randn(4, 3, 224, 224, device=device)
    input_n_2 = [torch.randn(4, 3, 224, 224, device=device) for _ in range(num_inputs)]
    input_n_3 = [torch.randn(4, 3, 224, 224, device=device) for _ in range(num_inputs)]

    # Zero the parameter gradients before the backward pass.
    optimizer.zero_grad()
    dist, out_1, out_n_2, out_n_3 = model(input_1, input_n_2, input_n_3)
    loss = criterion(out_1, out_n_2, out_n_3)
    loss.backward()
    train_loss = loss.item()
    # Take one optimization step.
    optimizer.step()

    if i % 10 == 0:
        model.eval()
        with torch.no_grad():
            val_dist, val_anc_out, val_pos_out, val_neg_out = model(val_1, val_n_2, val_n_3)
            val_loss = criterion(val_anc_out, val_pos_out, val_neg_out).item()
        print(i, "Train_loss: ", train_loss, ", VAL_LOSS: ", val_loss)