Result of autograd does not match that of self defined grad

I defined a loss and trained a model on randomly generated inputs; the code is shown below.
I created two identical models and trained them with the same inputs. The only difference is the implementation of the criteria: one uses autograd, and the other uses a self-defined gradient.

import torch
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothSoftmaxCEFunction(torch.autograd.Function):
    """Label-smoothed softmax cross-entropy with a hand-written backward pass.

    ``forward`` returns the loss summed over all samples; ``backward`` applies
    the closed-form gradient ``softmax(logits) - smoothed_target``.
    """

    @staticmethod
    def forward(ctx, logits, label, lb_smooth):
        """Compute the summed label-smoothed cross-entropy.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            logits: float tensor of raw scores with classes along dim 1.
            label: integer class indices; expanded with ``unsqueeze(1)`` and
                scattered along dim 1 to build the one-hot target.
            lb_smooth: smoothing factor; the target puts ``1 - lb_smooth`` on
                the true class plus ``lb_smooth / num_classes`` on every class.

        Returns:
            Scalar tensor holding the loss summed over all samples.
        """
        num_classes = logits.size(1)
        # log_softmax is numerically stable; log(softmax(x)) underflows to
        # log(0) = -inf when the logits are widely spread.
        log_probs = torch.log_softmax(logits, dim=1)
        lb_one_hot = torch.zeros_like(logits).scatter_(1, label.unsqueeze(1), 1)
        target = (1. - lb_smooth) * lb_one_hot + lb_smooth / num_classes

        # Save through the autograd API rather than as plain ctx attributes.
        ctx.save_for_backward(log_probs.exp(), target)

        loss = -torch.sum(log_probs * target, dim=1)
        return loss.sum()

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``logits``; other inputs get ``None``.

        Each row of the smoothed target sums to one, so the derivative of the
        summed cross-entropy w.r.t. the logits is ``softmax - target``.
        """
        scores, target = ctx.saved_tensors
        return grad_output * (scores - target), None, None


class LabelSmoothSoftmaxCEV2(nn.Module):
    """Module wrapper around ``LabelSmoothSoftmaxCEFunction``.

    Stores the smoothing factor and forwards it, together with the logits and
    labels, to the custom autograd function.
    """

    def __init__(self, lb_smooth):
        """Remember the smoothing factor ``lb_smooth`` for later calls."""
        super().__init__()
        self.lb_smooth = lb_smooth

    def forward(self, logits, label):
        """Return the loss computed by ``LabelSmoothSoftmaxCEFunction``."""
        smoothing = self.lb_smooth
        return LabelSmoothSoftmaxCEFunction.apply(logits, label, smoothing)


class LabelSmoothSoftmaxCEV1(nn.Module):
    """Label-smoothed softmax cross-entropy that relies on autograd.

    The loss is summed over all samples; gradients come from PyTorch's
    automatic differentiation of the forward computation.
    """

    def __init__(self, lb_smooth=0.1,):
        """Store the smoothing factor ``lb_smooth`` (default 0.1)."""
        super(LabelSmoothSoftmaxCEV1, self).__init__()
        self.lb_smooth = lb_smooth

    def forward(self, logits, label):
        """Compute the summed label-smoothed cross-entropy.

        Args:
            logits: float tensor of raw scores with classes along dim 1.
            label: integer class indices; expanded with ``unsqueeze(1)`` and
                scattered along dim 1 to build the one-hot target.

        Returns:
            Scalar tensor holding the loss summed over all samples.
        """
        num_classes = logits.size(1)
        # log_softmax is numerically stable; log(softmax(x)) underflows to
        # log(0) = -inf (and NaN gradients) for widely spread logits.
        log_probs = torch.log_softmax(logits, dim=1)

        # The smoothed target is a constant w.r.t. the logits, so it is built
        # outside the autograd graph.
        with torch.no_grad():
            lb_one_hot = torch.zeros_like(logits).scatter_(1, label.unsqueeze(1), 1)
            target = (1. - self.lb_smooth) * lb_one_hot + self.lb_smooth / num_classes

        loss = -torch.sum(log_probs * target, dim=1)
        return loss.sum()



if __name__ == '__main__':
    # Side-by-side comparison: two identically initialised ResNet-18 models
    # are trained on the same random batches, one with the autograd-based
    # criterion (V1) and one with the hand-written backward (V2).
    import random
    import numpy as np
    import torch
    import torchvision

    # Seed every RNG used here so both runs see the same data.
    seed = 15
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True

    # Models are built first (in order), then the criteria, then everything is
    # moved to the GPU before the optimizers are created.
    models = [torchvision.models.resnet18(pretrained=True) for _ in range(2)]
    criteria = [
        LabelSmoothSoftmaxCEV1(lb_smooth=0.1),
        LabelSmoothSoftmaxCEV2(lb_smooth=0.1),
    ]
    for module in models + criteria:
        module.cuda()
    optimizers = [torch.optim.SGD(m.parameters(), lr=1e-4) for m in models]

    for _ in range(10):
        # NOTE(review): the last dimension is 244, possibly a typo for 224 — confirm.
        inten = torch.randn(16, 3, 224, 244).cuda()
        lbs = torch.randint(0, 1000, (16, )).cuda()
        # Run one SGD step per (model, criterion) pair on the same batch and
        # print each loss so the two implementations can be compared.
        for net, criterion, optim in zip(models, criteria, optimizers):
            logits = net(inten)
            loss = criterion(logits, lbs)
            optim.zero_grad()
            loss.backward()
            optim.step()
            print(loss.item())
        print('===='*10)

What is the cause of the difference?

Hi,

How big is the difference?
Unfortunately, floating point computations are not associative. That means that if your version computes the same thing in a different order, it will be slightly different.
When doing gradient descent, such small differences will be amplified.

The way we usually check is by running a single backward pass with double precision numbers and make sure that the output is close enough to the input.
You can also try the gradcheck function to test your implementation.

After around 300k iterations of training on randomly generated inputs (float32), the summed error of the weight matrix is 0.0001. Is it normal for the error to accumulate to 1e-4 after 300k iterations?

Hi,

I would say it is surprisingly good accuracy after that many iterations.
Or that your training has a single nice unique solution and both converged to the same place (most likely).

1 Like