Triplet Loss on MNIST

Hi everyone, I’m trying to use AlexNet with a triplet loss on the MNIST dataset. I think there’s something wrong with how I choose the positive and negative for each triplet: for every anchor in the batch I pick the furthest positive (same label) and the closest negative (different label). The average triplet loss gets stuck at 1, which is exactly the margin. I’ve tried learning rates from 0.01 down to 0.000001, but it doesn’t help.
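For reference, a minimal sketch of the TripletLoss used below, assuming the standard margin formulation max(d(a, p) - d(a, n) + margin, 0) averaged over the batch (the built-in torch.nn.TripletMarginLoss with p=2 is equivalent):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TripletLoss(nn.Module):
    """Standard triplet margin loss: mean(max(d(a, p) - d(a, n) + margin, 0))."""

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        d_ap = F.pairwise_distance(anchor, positive, p=2)  # anchor-positive distances
        d_an = F.pairwise_distance(anchor, negative, p=2)  # anchor-negative distances
        return F.relu(d_ap - d_an + self.margin).mean()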

import numpy as np
import torch
import torch.optim as optim

# AlexNet (adapted for 1-channel input), device, and train_loader are defined elsewhere.
model = AlexNet(1).to(device)
criterion = TripletLoss(margin=1)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):
    
    running_loss = 0
  
    for i, (data, label) in enumerate(train_loader):

        # Embed the whole batch; every sample serves as an anchor.
        # (Variable is deprecated since PyTorch 0.4; tensors work directly.)
        a_input = data.to(device).float()
        a_output = model(a_input)

        p_index, n_index = [], []

        # For each anchor, mine the hardest positive (furthest same-label
        # sample) and the hardest negative (closest different-label sample).
        for index, anchor in enumerate(a_output):

            max_dist, min_dist = 0.0, float('inf')
            max_index, min_index = 0, 0
            for index_, anchor_ in enumerate(a_output):
                if index == index_:
                    continue
                a0 = anchor.detach().cpu().numpy()
                a1 = anchor_.detach().cpu().numpy()
                dist = np.linalg.norm(a0 - a1)
                if max_dist < dist and label[index] == label[index_]:
                    max_dist, max_index = dist, index_
                if min_dist > dist and label[index] != label[index_]:
                    min_dist, min_index = dist, index_
            p_index.append(max_index)
            n_index.append(min_index)

        # Gather the mined positives and negatives; indexing like this keeps
        # gradients flowing through the selected embeddings.
        p_output = torch.stack([a_output[j] for j in p_index])
        n_output = torch.stack([a_output[j] for j in n_index])

        optimizer.zero_grad()

        loss = criterion(a_output, p_output, n_output)
        loss.backward()

        optimizer.step()
        
        running_loss += loss.item()  # .item() avoids holding onto the graph
        
        if i % 100 == 99:
            print(f'epoch {epoch+1}, batch {i+1}, loss: {running_loss}')
            running_loss = 0
            
print('Finished')
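In case the nested Python loop obscures the mining logic, here is the same hardest-positive / hardest-negative selection written in a vectorized way (mine_hardest is just an illustrative helper name; it assumes a_output has shape [B, D] and label is an integer tensor of shape [B]):

import torch

def mine_hardest(a_output, label):
    """Same mining as the loop above: furthest positive, closest negative per anchor."""
    label = label.to(a_output.device)
    dist = torch.cdist(a_output.detach(), a_output.detach())  # [B, B] pairwise L2 distances (indices only, no grad needed)
    same = label.unsqueeze(0) == label.unsqueeze(1)           # [B, B] True where labels match
    diag = torch.eye(len(label), dtype=torch.bool, device=a_output.device)

    # Furthest same-label sample: mask everything else to -inf, then argmax.
    p_index = dist.masked_fill(~same | diag, float('-inf')).argmax(dim=1)
    # Closest different-label sample: mask same-label pairs to +inf, then argmin.
    n_index = dist.masked_fill(same, float('inf')).argmin(dim=1)
    return p_index, n_index

With that, p_output = a_output[p_index] and n_output = a_output[n_index] replace the torch.stack comprehensions. (Like the original loop, this silently falls back to index 0 when an anchor has no positive in the batch.)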

The printed loss is always a little above 100. Since running_loss is summed over 100 batches, that works out to an average of just over 1 per batch, i.e. the loss sits right at the margin.

Does anyone know what’s happening?
