Decaying learning rate spikes center loss

Hello,

I am implementing center loss in my application. Center loss was introduced in the ECCV 2016 paper "A Discriminative Feature Learning Approach for Deep Face Recognition". The idea is to cluster the features (embeddings) just before the last FC layer: center loss penalizes each embedding's distance to its class center and is optimized jointly with cross-entropy, so while cross-entropy tries to separate the classes, center loss pulls features of the same class towards each other. During training, the centers are updated towards the mean of the embeddings assigned to them. Center loss has been implemented by many people; I'm mainly using the implementation here, with minor tweaks: PyTorch CenterLoss. Here is the implementation:

import torch
import torch.nn as nn
from torch.autograd import Function


class CenterLoss(nn.Module):
    def __init__(self, num_classes, feat_dim, size_average=True):
        super(CenterLoss, self).__init__()
        # one learnable center per class, with the same dimensionality as the embeddings
        self.centers = nn.Parameter(torch.FloatTensor(num_classes, feat_dim))
        self.centerloss = CenterLossFunction.apply
        self.feat_dim = feat_dim
        self.size_average = size_average
        self.reset_params()

    def reset_params(self):
        nn.init.kaiming_normal_(self.centers.data.t())

    def forward(self, feat, label):
        batch_size = feat.size(0)
        feat = feat.view(batch_size, -1)
        # sanity check: the feature dimension must match the centers' dimension
        if feat.size(1) != self.feat_dim:
            raise ValueError("Center's dim: {0} should be equal to input feature's "
                             "dim: {1}".format(self.feat_dim, feat.size(1)))
        # normalize by the batch size if size_average, otherwise by 1
        batch_size_tensor = feat.new_empty(1).fill_(batch_size if self.size_average else 1)
        loss = self.centerloss(feat, label, self.centers, batch_size_tensor)
        return loss


class CenterLossFunction(Function):
    @staticmethod
    def forward(ctx, feature, label, centers, batch_size):
        ctx.save_for_backward(feature, label, centers, batch_size)
        # select each sample's class center and compute
        # L_C = sum_i ||x_i - c_{y_i}||^2 / (2 * batch_size)
        centers_batch = centers.index_select(0, label.long())
        return (feature - centers_batch).pow(2).sum() / 2.0 / batch_size

    @staticmethod
    def backward(ctx, grad_output):
        feature, label, centers, batch_size = ctx.saved_tensors
        centers_batch = centers.index_select(0, label.long())
        diff = centers_batch - feature

        # re-initialized every iteration; counts start at one so the division
        # below uses (1 + n_j), as in the paper's center-update rule
        counts = centers.new_ones(centers.size(0))
        ones = centers.new_ones(label.size(0))
        grad_centers = centers.new_zeros(centers.size())

        # n_j: how many samples of class j are in the batch
        counts.scatter_add_(0, label.long(), ones)
        # per-class sum of (c_{y_i} - x_i), divided by (1 + n_j); note this is the
        # paper's center update rather than the analytic gradient, and it is not
        # scaled by grad_output
        grad_centers.scatter_add_(0, label.unsqueeze(1).expand(feature.size()).long(), diff)
        grad_centers = grad_centers / counts.view(-1, 1)

        # gradient w.r.t. the features: (x_i - c_{y_i}) / batch_size
        return -grad_output * diff / batch_size, None, grad_centers, None
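For reference, here is a small standalone sanity check I put together (my own sketch, not from the linked repo): it verifies that the custom forward equals the plain expression sum_i ||x_i - c_{y_i}||^2 / (2 * batch_size), and that the gradient w.r.t. the features matches what autograd computes for that expression. The shapes (batch of 4, 8 classes, 512-dim features) and tensor names are just placeholders matching my setup.

import torch

# my own sanity check, not part of the repo above
feat = torch.randn(4, 512, requires_grad=True)   # fake batch of embeddings
label = torch.randint(0, 8, (4,))                # fake labels for 8 classes
centers = torch.randn(8, 512)                    # centers treated as constants here
bs = feat.new_full((1,), 4.0)                    # batch-size tensor, as built in CenterLoss.forward

loss_custom = CenterLossFunction.apply(feat, label, centers, bs)

# reference: the same loss written directly, differentiated by autograd
feat_ref = feat.detach().clone().requires_grad_(True)
loss_ref = (feat_ref - centers[label]).pow(2).sum() / 2.0 / bs

loss_custom.backward()
loss_ref.backward()

print(torch.allclose(loss_custom, loss_ref))     # True
print(torch.allclose(feat.grad, feat_ref.grad))  # True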

In my application (image classification) I have training, validation and test sets. Here is my main script, which optimizes the center loss jointly with cross-entropy:

def main(cfg):
    
    # Loading dataset
    # ---------------
    # train_loader = ...
    # valid_loader = ...
    # test_loader = ...
    
    # Create Model
    # ------------
    model = resnet18(in_channels=1, num_classes=8)
    model = torch.nn.DataParallel(model).to(device)

    # define loss function (criterion) and optimizer
    # ----------------------------------------------
    criterion = {}
    optimizer = {}
    criterion['xent'] = nn.CrossEntropyLoss().to(device)
    optimizer['xent'] = torch.optim.SGD(model.parameters(), cfg['lr'], 
                                        momentum=cfg['momentum'],
                                        weight_decay=cfg['weight_decay'])
    
    criterion['center'] = CenterLoss(8, 512).to(device)
    optimizer['center'] = torch.optim.SGD(criterion['center'].parameters(), cfg['alpha'])

    # training and evaluation
    # -----------------------
    for epoch in range(cfg['epochs']):
        adjust_learning_rate(optimizer['xent'], epoch, cfg)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, cfg)
        # evaluate on validation set
        validate(valid_loader, model, criterion, epoch, cfg)
        # evaluate on test set
        test(test_loader, model, criterion, epoch, cfg)

def train(train_loader, model, criterion, optimizer, epoch, cfg):
    losses = {}
    losses['xent'] = AverageMeter()
    losses['center'] = AverageMeter()
    losses['loss'] = AverageMeter()
    train_acc = AverageMeter()
    y_pred, y_true, y_scores = [], [], []

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):

        images = images.to(device)
        target = target.to(device)
        # compute output
        feat, output = model(images)
        xent_loss = criterion['xent'](output, target)
        center_loss = criterion['center'](feat, target)
        loss = xent_loss + cfg['lamb'] * center_loss

        # measure accuracy and record loss
        acc, pred = accuracy(output, target)
        losses['xent'].update(xent_loss.item(), images.size(0))
        losses['center'].update(center_loss.item(), images.size(0))
        losses['loss'].update(loss.item(), images.size(0))
        train_acc.update(acc.item(), images.size(0))

        # collect for other measures
        y_pred.append(pred)
        y_true.append(target)
        y_scores.append(output.detach())

        # compute gradient and do SGD step
        optimizer['xent'].zero_grad()
        optimizer['center'].zero_grad()
        loss.backward()
        optimizer['xent'].step()
        optimizer['center'].step()

    m = metrics(y_pred, y_true, y_scores)
    progress = (...)
    tb_write(losses, train_acc.avg, m, epoch, writer=writer, tag='train')


def validate(valid_loader, model, criterion, epoch, cfg):
    losses = {}
    losses['xent'] = AverageMeter()
    losses['center'] = AverageMeter()
    losses['loss'] = AverageMeter()
    valid_acc = AverageMeter()
    y_pred, y_true, y_scores = [], [], []

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        for i, (images, target) in enumerate(valid_loader):

            images = images.to(device)
            target = target.to(device)

            # compute output
            feat, output = model(images)
            xent_loss = criterion['xent'](output, target)
            center_loss = criterion['center'](feat, target)
            loss = xent_loss + cfg['lamb'] * center_loss

            # measure accuracy and record loss
            acc, pred = accuracy(output, target)
            losses['xent'].update(xent_loss.item(), images.size(0))
            losses['center'].update(center_loss.item(), images.size(0))
            losses['loss'].update(loss.item(), images.size(0))
            valid_acc.update(acc.item(), images.size(0))

            # collect for other measures
            y_pred.append(pred)
            y_true.append(target)
            y_scores.append(output.detach())

    m = metrics(y_pred, y_true, y_scores)
    progress = (...)
    tb_write(losses, valid_acc.avg, m, epoch, writer=writer, tag='valid')


def test(test_loader, model, criterion, epoch, cfg):
    '''
    the same as validation
    '''


def adjust_learning_rate(optimizer, epoch, cfg):
    if epoch < 20:
        lr = cfg['lr']
    elif 20 <= epoch < 40:
        lr = cfg['lr'] / 2.0
    else:
        lr = cfg['lr'] / 10.0
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


if __name__ == '__main__':
    # loading config
    cfg = yaml.safe_load(open('config.yaml', 'r')) 
    main(cfg)
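As an aside, the step schedule in adjust_learning_rate could also be written with torch.optim.lr_scheduler.LambdaLR; this is just a sketch with a dummy optimizer (in my actual runs I use the manual function above):

import torch

# sketch only: same schedule, expressed as multiplicative factors of the base lr
params = [torch.nn.Parameter(torch.zeros(1))]   # dummy parameter
opt = torch.optim.SGD(params, lr=0.05)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    opt, lr_lambda=lambda epoch: 1.0 if epoch < 20 else 0.5 if epoch < 40 else 0.1)

for epoch in range(60):
    opt.step()        # training for one epoch would go here
    scheduler.step()  # lr: 0.05 for epochs 0-19, 0.025 for 20-39, 0.005 afterwards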

I use lamb=0.01 and alpha=0.5 for the center loss, and lr=0.05 for the cross-entropy SGD optimizer. The learning rate is decayed by a factor of 2 at epoch 20 and by a factor of 10 at epoch 40.
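For completeness, cfg is just a flat dict loaded from config.yaml; with these settings it looks roughly like the following (momentum, weight_decay and epochs below are placeholder values, not necessarily the ones I use):

# rough shape of cfg after yaml.safe_load; placeholder values are marked
cfg = {
    'lr': 0.05,             # base lr for the cross-entropy SGD optimizer
    'momentum': 0.9,        # placeholder
    'weight_decay': 1e-4,   # placeholder
    'alpha': 0.5,           # lr of the SGD optimizer that moves the centers
    'lamb': 0.01,           # weight of the center loss in the total loss
    'epochs': 60,           # placeholder
}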
Here is what I see when logging my scalars:

[Screenshots of the logged scalars: training loss and center loss curves over epochs]

You can see that the overall training loss decreases nicely and slowly. However, the center loss tends to spike exactly at the epochs where I decay the learning rate, and then starts decreasing again. How is this happening? Is this normal? I've checked the model parameters and I don't see any weights changing abnormally at those epochs.

Update: when I use PReLU layers instead of ReLU layers in my ResNet, the problem goes away. No idea why!