Why is there no gradient for a Linear layer?

I am trying to develop a model which takes a sequence of word embeddings of shape (ts, dim) as input and returns several groups of word embeddings of shape (max_size, dim), where ts is the number of words, dim is the dimension of the word embeddings, and max_size is the maximum number of groups.

To achieve this, we simply apply a linear layer to each word embedding and project it to 3 classes, namely (0, 1, 2), where 0 means the word should not be retrieved, 1 means the word is inside a retrieved span, and 2 means the word is the first token of a retrieved span. Then we group the embeddings by these tags.

Pseudo code for the algorithm is shown below:

>>> embs = [[ 1.,  1.,  1.,  1.],
            [ 2.,  2.,  2.,  2.],
            [ 3.,  3.,  3.,  3.],
            [ 4.,  4.,  4.,  4.],
            [ 5.,  5.,  5.,  5.],
            [ 6.,  6.,  6.,  6.],
            [ 7.,  7.,  7.,  7.],
            [ 8.,  8.,  8.,  8.]] # assume we have 8 words
>>> tags = argmax(softmax(Linear(embs), -1), -1) # shape=(8,)
[2, 0, 2, 1, 1, 0, 0, 1]
>>> group_tags = [1, 0, 2, 2, 2, 0, 0, 0]
>>> group_embs = [[1., 1., 1., 1.],  # the first group, the mean of embs[0:1]
                  [4., 4., 4., 4.]]  # the second group, the mean of embs[2:5]

I know that the argmax operation is not differentiable, so I replaced torch.argmax with a soft_argmax. I also got some help from Generate groupby index for a tensor - #7 by Xuansheng_Wu and Groupby aggregate mean in pytorch - PyTorch Forums, so that I can compute group_tags and group_embs with gradients.
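
For reference, the soft argmax I use is the expectation of the tag index under the softmax probabilities. A minimal, self-contained sketch of the idea (random inputs and three tag classes, purely for illustration):

import torch as tc

probs = tc.softmax(tc.randn(8, 3, requires_grad=True), dim=-1)  # (seq, num_tags)
hard_tags = probs.argmax(dim=-1)          # integer tags, not differentiable
labels = tc.arange(3, dtype=probs.dtype)  # the tag classes 0, 1, 2
soft_tags = (probs * labels).sum(dim=-1)  # float "tags", still connected to the graph

When the softmax is confident, the soft tags land close to 0, 1, or 2, and group_by_index in the code below rebuilds the groups by thresholding these float values.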

However, I found that the Linear layer does not get a gradient at all!

I provide a minimal reproducible example here; I hope it saves you some time:

import torch as tc


def soft_argmax(embs):
    # differentiable stand-in for argmax: the expected tag index under the softmax output
    labels = tc.tensor([0, 1, 2], dtype=embs.dtype, requires_grad=False, device=embs.device)
    return (embs * labels).sum(axis=-1)


def group_by_index(a):  # solution from https://discuss.pytorch.org/t/groupby-aggregate-mean-in-pytorch/45335
    b = a.unsqueeze(-2)
    c = tc.transpose(b, -1, -2) @ b                # pairwise products a_i * a_j, shape (..., seq, seq)
    x = tc.tril(tc.ones_like(c), -1) + tc.triu(c)  # keep the upper triangle, fill the lower triangle with ones
    tmp = (x <= 0.6667).cumsum(axis=-1)            # find the first small product in each row ...
    x[tmp > 0] = 0                                 # ... and zero out everything from there on
    x, _ = tc.max(x, dim=-2)                       # for each position, the largest surviving product
    x[x < 1.3333] = 0                              # threshold into a 0/1 span-membership mask
    x[x > 1.3333] = 1
    vals = (a >= 1.3333).cumsum(axis=-1)           # running count of span-start tags (~2) gives the group id
    return x * vals                                # group id inside spans, 0 outside


def mean_by_label(samples, labels):  # https://discuss.pytorch.org/t/groupby-aggregate-mean-in-pytorch/45335
    # expand the group id of each word to (seq, dim) so scatter_add_ can sum the embeddings per group
    labels = labels.view(labels.size(0), 1).expand(-1, samples.size(1)).long()
    unique_labels, labels_count = labels.unique(dim=0, return_counts=True)
    res = tc.zeros_like(unique_labels, dtype=samples.dtype)
    res = res.scatter_add_(0, labels, samples)
    return res / labels_count.float().unsqueeze(1)  # per-group sum divided by group size = per-group mean


class Mention(tc.nn.Module):
    def __init__(self, emb_dim, max_mention):
        super(Mention, self).__init__()
        self.fc = tc.nn.Sequential(tc.nn.Linear(emb_dim, 3),
                                   tc.nn.Softmax(-1))
        self.max_size = max_mention
        self.empty = tc.zeros((1, emb_dim), requires_grad=False)

    def forward(self, batch_embs):
        batch_tags = soft_argmax(self.fc(batch_embs)) # (bs, seq)
        group_tags = group_by_index(batch_tags)  # (bs, seq)
        batch_attr_emb = []
        for sample_tags, sample_embs in zip(group_tags, batch_embs):
            group_embs = mean_by_label(sample_embs, sample_tags) # (num_tag, dim)
            num_tag = group_embs.shape[0] - 1
            group_embs = tc.vstack([group_embs] + [self.empty] * (self.max_size - num_tag))
            batch_attr_emb.append(group_embs[1:self.max_size + 1])
        return tc.stack(batch_attr_emb, axis=0)

    def check_grad(self):
        for n, p in self.named_parameters():
            if p.requires_grad and "bias" not in n:
                if not hasattr(p, "grad") or p.grad is None:
                    print("No gradient:", n)

    def fake_compute_loss(self, y):
        return y.mean()


if __name__ == "__main__":
    model = Mention(emb_dim=4, max_mention=10)
    emb = tc.tensor([[[float(i)] * 4 for i in range(1, 17)]] * 2, requires_grad=True)
    pred = model(emb)
    loss = model.fake_compute_loss(pred)
    loss.backward()
    model.check_grad()
        

If you download and run the above code, you will get No gradient: fc.0.weight in your terminal. I am looking for a solution that can pass this unit test.

Your model is partially working and the input (emb) will get proper gradients.
However, the parameters of self.fc won’t get valid gradients, as you are transforming its outputs to a LongTensor, which will break the computation graph in mean_by_label.
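
To illustrate the issue with a minimal (simplified) snippet: integer tensors cannot require gradients, so the moment the soft tags are cast to long, everything downstream is detached from self.fc:

import torch as tc

lin = tc.nn.Linear(4, 3)
soft_tags = (lin(tc.randn(8, 4)).softmax(-1) * tc.arange(3.)).sum(-1)
print(soft_tags.requires_grad)         # True: still connected to lin
print(soft_tags.long().requires_grad)  # False: the cast to LongTensor cuts the graph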

Hi ptrblck, thanks for your response!

Can you give me any suggestions on how to implement this function? Since scatter_add_ expects integer indices, I can't simply remove the .long() call from mean_by_label.
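
To make the constraint concrete, here is a small sketch of what I mean (shapes chosen to match my example above; the names are only for illustration):

import torch as tc

res = tc.zeros(3, 4)                               # (num_groups, dim)
samples = tc.ones(8, 4)                            # (seq, dim)
group_tags = tc.tensor([1., 0, 2, 2, 2, 0, 0, 0])  # float group ids coming out of the model
index = group_tags.view(-1, 1).expand(-1, 4)
# res.scatter_add_(0, index, samples)              # fails: the index must be an integer tensor
res.scatter_add_(0, index.long(), samples)         # works, but .long() is not differentiable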

It seems you want to use the floating point outputs of self.fc as an index, which won’t be differentiable. I’m not familiar enough with your use case to completely understand what you are trying to achieve.
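
In case it helps, one common direction (just a rough sketch, which might not fit your use case) is to avoid the hard index entirely and aggregate with soft weights, e.g. a weighted mean per tag using the class probabilities directly, so the gradient can still reach the linear layer:

import torch as tc

embs = tc.randn(8, 4)                              # (seq, dim)
fc = tc.nn.Linear(4, 3)
probs = fc(embs).softmax(-1)                       # (seq, num_tags), differentiable
weights = probs / probs.sum(dim=0, keepdim=True)   # normalize the weights over the sequence
group_embs = weights.t() @ embs                    # (num_tags, dim): one weighted mean per tag, grads reach fc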

Thank you ptrblck for your advice!
I will rethink my model.