Debugging gradient discrepancy between autograd and manual backward implementation

I am trying to implement the “soft-softmax” loss for multi-label classification that is very briefly described on page 5 (loss function) of https://arxiv.org/abs/1805.00932.

I have an implementation where I computed the gradients myself and implemented a custom loss function with forward-backward (see CategoricalMultiLabelLoss below). I used this blogpost as a reference.

To simplify, I tried to compare this with computing the loss using pytorch operators and let autograd do the work. While the forward passes yield identical loss value, the parameter gradients when I do a backward pass are different.

Can someone help me understand why this might be the case and what I am missing? Is the loss non-differentiable at certain points due to which autograd might not work well? (I didn’t think so but may be I’m wrong)

import torch
import torch.nn as nn

# custom loss implementation
class CategoricalMultiLabelLoss(torch.autograd.Function):  
    @staticmethod
    def forward(ctx, inputs, targets):
        scores = inputs.data
        batch_size = len(scores)
        # Compute cross-entropy loss
        scale_factor = (targets > 0).sum(dim=1, keepdim=True).float()
        logprobs = (scores * targets).div(scale_factor)
        data_loss = - torch.sum(logprobs)/ batch_size
        ctx.save_for_backward(inputs, targets)
        return data_loss

    @staticmethod
    def backward(ctx, grad_output):
        inputs, targets = ctx.saved_tensors
        scale_factor = (targets > 0).sum(dim=1, keepdim=True).float()
        scale_factor_repmat= scale_factor.repeat(1,targets.size(1))
        print(scale_factor.shape, scale_factor_repmat.shape, scale_factor_repmat)
        delta = inputs.data.exp()   # If the class label is 0, the gradient is equal to probs
        delta = scale_factor_repmat *(delta-1) + (1-scale_factor_repmat)*delta
        delta = delta * targets
        return delta, None
    
# dummy model
class DummyModel(nn.Module):
    def __init__(self, num_inputs, emb_size, emb_dropout, hidden_size, num_outputs, pad_index):
        self.num_inputs = num_inputs
        self.emb_size = emb_size
        self.emb_dropout = emb_dropout

        self.hidden_size = hidden_size
        self.num_outputs = num_outputs
        self.pad_index = pad_index

        super(DummyModel, self).__init__()

        self.emb = nn.Embedding(num_inputs, emb_size, padding_idx=pad_index)
        self.emb_dropout = nn.Dropout(p=emb_dropout)        
        self.fc1 = nn.Linear(emb_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_outputs)
        self.logsoftmax = nn.LogSoftmax()

    def forward(self, input_seq, length, masks, debug=False):
        embs = self.emb(input_seq)
        embs_dropout = self.emb_dropout(embs)
        avg = torch.div(
            torch.sum(embs_dropout, axis=1),
            length
        )
        out1 = self.relu(self.fc1(avg))
        out2 = self.fc2(out1)
        
        # constraint: do not predict masks so set these scores to large negative value
        out2.scatter_(1, masks, -1e5)
        return self.logsoftmax(out2)
    
# initialize two identical models with dropout set to zero
num_inputs = 10
emb_size = 2
emb_dropout = 0.0
hidden_size = 1
num_outputs = 5
PAD_INDEX = 0
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

ma = DummyModel(
    num_inputs,
    emb_size,
    emb_dropout,
    hidden_size,
    num_outputs,
    PAD_INDEX
)
ma.to(DEVICE)

mn = DummyModel(
    num_inputs,
    emb_size,
    emb_dropout,
    hidden_size,
    num_outputs,
    PAD_INDEX
)
mn.to(DEVICE)
mn.load_state_dict(ma.state_dict())

# dummy data
ex = (
    # inputs
    torch.cuda.LongTensor([
        [2, 3, 0],
        [4, 2, 1]
    ]),
    # targets
    torch.cuda.FloatTensor([
        [0, 0, 0, 0.5, 0.5],
        [0, 1.0, 0, 0, 0]
    ]),
    # lengths
    torch.cuda.LongTensor([
        [2],
        [3]
    ]),
    # masks
    torch.cuda.LongTensor([
        [2],
        [4]
    ])
)

# Use custom loss with backward impl
i1, t1, il1, m1 = ex
o1 = ma(i1, il1, m1)
loss1 = CategoricalMultiLabelLoss.apply(o1, t1)
print(loss1)
o1.retain_grad()
loss1.backward()
print(o1.grad)
for p in ma.parameters():
    print(p.grad)

i2, t2, il2, m2 = ex
o2 = mn(i2, il2, m2)
loss2 = torch.mean( # average CE across batch
    torch.div(
        torch.sum( # compute CE
            (-t2 * o2), dim=1
        ),
        (t2 > 0).sum(dim=1) # scaling factor
    )
)
print(loss2)
o2.retain_grad()
loss2.backward()
print(o2.grad)
for p in mn.parameters():
    print(p.grad)

This is the output:

tensor(1.2839, device='cuda:0', grad_fn=<CategoricalMultiLabelLossBackward>)
torch.Size([2, 1]) torch.Size([2, 5]) tensor([[2., 2., 2., 2., 2.],
        [1., 1., 1., 1., 1.]], device='cuda:0')
tensor([[-0.0000, -0.0000, -0.0000, -0.9301, -0.8291],
        [-0.0000, -0.8359, -0.0000, -0.0000, -0.0000]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 0.]], device='cuda:0')
tensor([0.], device='cuda:0')
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.]], device='cuda:0')
tensor([ 0.9898, -0.4367,  0.2301, -0.5554, -0.2278], device='cuda:0')

tensor(1.2839, device='cuda:0', grad_fn=<MeanBackward0>)
tensor([[-0.0000, -0.0000, -0.0000, -0.1250, -0.1250],
        [-0.0000, -0.5000, -0.0000, -0.0000, -0.0000]], device='cuda:0')
tensor([[0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]], device='cuda:0')
tensor([[0., 0.]], device='cuda:0')
tensor([0.], device='cuda:0')
tensor([[0.],
        [0.],
        [0.],
        [0.],
        [0.]], device='cuda:0')
tensor([ 0.2957, -0.3807,  0.1376, -0.0131, -0.0396], device='cuda:0')