Implementation of Large Margin Deep Networks for Classification

Hi, I am trying to implement the paper Large Margin Deep Networks for Classification, based on this TensorFlow code. Essentially, the idea is to take a first-order (linear) approximation of the distance from each sample to each class decision boundary and build an SVM-like objective from it. So, for each sample, we need the norm of the gradient, with respect to the input, of the difference between the true-class probability and each of the other num_classes - 1 class probabilities, which is what the inner loop below computes.
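For reference, this is my reading of the per-pair distance from the paper, as a minimal standalone sketch (the toy nn.Linear model and the class indices here are purely illustrative, not the actual network):

import torch
import torch.nn as nn

# First-order (linear) approximation of the distance from x to the decision
# boundary between the true class i and some other class j:
#   d_{i->j}(x) ~= (f_i(x) - f_j(x)) / || grad_x (f_i(x) - f_j(x)) ||_2
model = nn.Linear(784, 10)                  # stand-in for the real network
x = torch.randn(1, 784, requires_grad=True)
logits = model(x)

i, j = 3, 7                                 # illustrative class indices
diff = logits[0, i] - logits[0, j]
grad = torch.autograd.grad(diff, x)[0]
distance = diff / (1e-6 + grad.norm(p=2))
print(distance)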

However, whereas the TensorFlow code works as expected, my PyTorch code does not: after the first batch the loss tends to the upper distance bound (200). I am not sure where the problem is and would really appreciate any help.

The really odd part is that if I comment out this line,

data.grad.data.zero_()

the model trains fine, reaching > 98% test accuracy (albeit slowly). So letting the input gradients accumulate across the loop actually helps learning, which suggests that when the line is left in, the gradient norms shrink rapidly and learning stalls.
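Just to spell out what that line changes (a standalone toy example, nothing to do with the actual model):

import torch

# Successive backward() calls add into x.grad, so without zeroing,
# the gradients from each pair pile up rather than being replaced.
x = torch.ones(3, requires_grad=True)

(x * 2).sum().backward()
print(x.grad)        # tensor([2., 2., 2.])

(x * 5).sum().backward()
print(x.grad)        # tensor([7., 7., 7.])  -- accumulated, not replaced

x.grad.zero_()       # this is what data.grad.data.zero_() resets each iteration
print(x.grad)        # tensor([0., 0., 0.])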

If you want to try this for yourself, the method below is a direct drop-in replacement for train() in the official PyTorch MNIST example.

import numpy as np  # needed for the fancy indexing below; not imported in the MNIST example

def maximum_with_relu(a, b):
    return a + F.relu(b - a)

def train(args, model, device, train_loader, optimizer, epoch, num_classes=10):
    model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        data.requires_grad_()
        optimizer.zero_grad()
        output = model(data)

        # get the difference between true class probability 
        # and all other probabilities (for MNIST this is the topk = 9)
        class_prob = torch.exp(output)
        correct_class_prob = class_prob.gather(1, target.unsqueeze(1))
        other_class_prob = class_prob.clone().requires_grad_()
        other_class_prob[np.arange(len(other_class_prob)), target] = 0
        topk_other_class_prob, topk_other_class_idx = torch.topk(other_class_prob, k=num_classes-1, dim=1)
        difference_prob = correct_class_prob - topk_other_class_prob

        # find difference_prob / || grad of input wrt difference prob ||_2
        # loop through each probability found above and use this as a loss 
        # to find the norm of the gradient of the input
        for i in range(difference_prob.size(0)):
            for j in range(difference_prob.size(1)):
                difference_prob[i][j].backward(retain_graph=True)
                difference_prob[i][j] /= (1e-6 + data.grad.data.norm(p=2))
                data.grad.data.zero_()

        distance_to_boundary = torch.mean(difference_prob, 1)

        # Only consider distances between the distance_lower and distance_upper bounds
        distance_upper = 200
        distance_lower = 200 * (1 - 2)

        # Enforce lower bound.
        loss_layer = maximum_with_relu(distance_to_boundary, distance_lower)

        # Enforce upper bound.
        loss = maximum_with_relu(
            0, distance_upper - loss_layer) - distance_upper
        loss = torch.mean(loss)
        loss.backward()
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
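For what it's worth, I also looked at torch.autograd.grad as a way to get the per-pair norm without touching data.grad at all; here is a minimal standalone check (toy model purely for illustration) that it returns the same gradient the backward()/zero_() pattern in the inner loop produces:

import torch
import torch.nn as nn

# torch.autograd.grad returns the input gradient for a single scalar output
# without writing into data.grad, so it could replace backward() + zero_().
model = nn.Linear(4, 3)
data = torch.randn(2, 4, requires_grad=True)
out = model(data)

scalar = out[0, 1]                                   # one (sample, class) entry
g = torch.autograd.grad(scalar, data, retain_graph=True)[0]
print(g.norm(p=2))

scalar.backward()                                    # same gradient via backward()
print(data.grad.norm(p=2))                           # matches the norm above
data.grad.zero_()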
