I am trying to develop a model that takes a sequence of word embeddings of shape `(ts, dim)` as input and returns several groups of word embeddings of shape `(max_size, dim)`, where `ts` is the number of words, `dim` is the dimension of a word embedding, and `max_size` is the maximum number of groups.
To achieve this, we apply a linear layer to each word embedding and project it to 3 dimensions, one per tag (0, 1, 2): 0 means the word should not be retrieved, 1 means the word is inside a retrieved span, and 2 means the word is the first token of a retrieved span. We then group the embeddings by these tags. The following pseudocode illustrates the algorithm:
```
>>> embs = [[1., 1., 1., 1.],
            [2., 2., 2., 2.],
            [3., 3., 3., 3.],
            [4., 4., 4., 4.],
            [5., 5., 5., 5.],
            [6., 6., 6., 6.],
            [7., 7., 7., 7.],
            [8., 8., 8., 8.]]  # assume we have 8 words
>>> tags = argmax(softmax(Linear(embs), -1), -1)  # shape=(8,)
>>> tags
[2, 0, 2, 1, 1, 0, 0, 1]
>>> group_tags = [1, 0, 2, 2, 2, 0, 0, 0]
>>> group_embs = [[1., 1., 1., 1.],  # the first group, same as embs[0:1].mean(0)
                  [4., 4., 4., 4.]]  # the second group, same as embs[2:5].mean(0)
```
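For reference, the tagging rule in the pseudocode can be written as a plain, non-differentiable function that takes hard tags and returns `group_tags` and `group_embs`. This is only a sketch, not the model itself: the name `group_spans` is hypothetical, and treating a `1` that follows a `0` as "not retrieved" is an assumption inferred from the last word in the example above.

```python
import torch


def group_spans(tags, embs):
    """Hard (non-differentiable) reference for the tagging scheme.

    tags: list of ints, 0 = not retrieved, 1 = inside a span, 2 = first token of a span
    embs: (ts, dim) tensor of word embeddings
    Returns (group_tags, group_embs): group_tags[i] is the 1-based span id of
    word i (0 if the word is in no span); group_embs has shape (num_spans, dim).
    """
    group_tags = []
    span_id = 0
    in_span = False
    for t in tags:
        if t == 2:                  # a new span starts at this word
            span_id += 1
            in_span = True
            group_tags.append(span_id)
        elif t == 1 and in_span:    # the current span continues
            group_tags.append(span_id)
        else:                       # tag 0, or a 1 with no open span (assumption)
            in_span = False
            group_tags.append(0)

    if span_id == 0:                # no spans found: return an empty (0, dim) tensor
        return group_tags, embs.new_zeros((0, embs.shape[1]))

    # Mean of the member embeddings of each span, in span order
    group_embs = torch.stack([
        embs[[i for i, g in enumerate(group_tags) if g == k]].mean(dim=0)
        for k in range(1, span_id + 1)
    ])
    return group_tags, group_embs


embs = torch.arange(1, 9, dtype=torch.float32).unsqueeze(1).repeat(1, 4)  # rows 1..8
group_tags, group_embs = group_spans([2, 0, 2, 1, 1, 0, 0, 1], embs)
print(group_tags)           # [1, 0, 2, 2, 2, 0, 0, 0]
print(group_embs.tolist())  # [[1.0, 1.0, 1.0, 1.0], [4.0, 4.0, 4.0, 4.0]]
```

This version uses Python control flow and integer indexing, so it only documents the intended forward behaviour; it carries no gradient back to the tags.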
I know that the `argmax` operation is not differentiable, so I replaced `torch.argmax` with `soft_argmax`. I also got some help from Generate groupby index for a tensor - #7 by Xuansheng_Wu and Groupby aggregate mean in pytorch - PyTorch Forums, so that I can compute `group_tags` and `group_embs` with gradients.
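The idea behind `soft_argmax` is to replace the hard index with the expected tag under the softmax distribution, i.e. sum_k k * p_k for k in {0, 1, 2}. A minimal sketch, assuming a freshly initialised `nn.Linear(4, 3)` stands in for `self.fc` and random vectors stand in for the word embeddings, checks that the soft tags stay in the autograd graph (and reach the linear weights) while `argmax` returns an integer tensor with no gradient.

```python
import torch

torch.manual_seed(0)


def soft_argmax(probs: torch.Tensor) -> torch.Tensor:
    """Expected tag index sum_k k * p_k over the last dimension."""
    labels = torch.arange(probs.shape[-1], dtype=probs.dtype, device=probs.device)
    return (probs * labels).sum(dim=-1)


fc = torch.nn.Linear(4, 3)   # stands in for self.fc: 3 logits, one per tag
embs = torch.randn(8, 4)     # 8 words, dim = 4

probs = torch.softmax(fc(embs), dim=-1)   # (8, 3)
soft_tags = soft_argmax(probs)            # (8,), real values in [0, 2]
hard_tags = probs.argmax(dim=-1)          # (8,), int64, not tracked by autograd

print(soft_tags.requires_grad, hard_tags.requires_grad)  # True False

soft_tags.sum().backward()
print(fc.weight.grad is not None)  # True
```

The soft tags are real numbers (e.g. somewhere between 1 and 2) rather than exact 0/1/2 values, so any later step that needs discrete group ids has to turn them back into integers at some point.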
However, I found that the `Linear` layer receives no gradient at all!
Here is a minimal reproducible example; I hope it saves you some time:
```python
import torch as tc


def soft_argmax(embs):
    # embs: softmax probabilities over the tags (0, 1, 2); returns the expected tag value
    labels = tc.tensor([0, 1, 2], dtype=embs.dtype, requires_grad=False, device=embs.device)
    return (embs * labels).sum(axis=-1)


def group_by_index(a):  # solution from https://discuss.pytorch.org/t/groupby-aggregate-mean-in-pytorch/45335
    # a: (bs, seq) soft tags -> (bs, seq) group ids
    b = a.unsqueeze(-2)
    c = tc.transpose(b, -1, -2) @ b
    x = tc.tril(tc.ones_like(c), -1) + tc.triu(c)
    tmp = (x <= 0.6667).cumsum(axis=-1)
    x[tmp > 0] = 0
    x, _ = tc.max(x, dim=-2)
    x[x < 1.3333] = 0
    x[x > 1.3333] = 1
    vals = (a >= 1.3333).cumsum(axis=-1)
    return x * vals


def mean_by_label(samples, labels):  # https://discuss.pytorch.org/t/groupby-aggregate-mean-in-pytorch/45335
    # group-by mean of the rows of samples (seq, dim) by label -> (num_labels, dim)
    labels = labels.view(labels.size(0), 1).expand(-1, samples.size(1)).long()
    unique_labels, labels_count = labels.unique(dim=0, return_counts=True)
    res = tc.zeros_like(unique_labels, dtype=samples.dtype)
    res = res.scatter_add_(0, labels, samples)
    return res / labels_count.float().unsqueeze(1)


class Mention(tc.nn.Module):
    def __init__(self, emb_dim, max_mention):
        super(Mention, self).__init__()
        self.fc = tc.nn.Sequential(tc.nn.Linear(emb_dim, 3),
                                   tc.nn.Softmax(-1))
        self.max_size = max_mention
        self.empty = tc.zeros((1, emb_dim), requires_grad=False)

    def forward(self, batch_embs):
        batch_tags = soft_argmax(self.fc(batch_embs))  # (bs, seq)
        group_tags = group_by_index(batch_tags)  # (bs, seq)
        batch_attr_emb = []
        for sample_tags, sample_embs in zip(group_tags, batch_embs):
            group_embs = mean_by_label(sample_embs, sample_tags)  # (num_tag, dim)
            num_tag = group_embs.shape[0] - 1
            # pad with zero rows, drop row 0 (label 0), keep at most max_size groups
            group_embs = tc.vstack([group_embs] + [self.empty] * (self.max_size - num_tag))
            batch_attr_emb.append(group_embs[1:self.max_size + 1])
        return tc.stack(batch_attr_emb, axis=0)

    def check_grad(self):
        for n, p in self.named_parameters():
            if p.requires_grad and "bias" not in n:
                if not hasattr(p, "grad") or p.grad is None:
                    print("No gradient:", n)

    def fake_compute_loss(self, y):
        return y.mean()


if __name__ == "__main__":
    model = Mention(emb_dim=4, max_mention=10)
    emb = tc.tensor([[[float(i)] * 4 for i in range(1, 17)]] * 2, requires_grad=True)
    pred = model(emb)
    loss = model.fake_compute_loss(pred)
    loss.backward()
    model.check_grad()
```
If you run the code above, it prints `No gradient: fc.0.weight` in the terminal. I am looking for a solution that passes this check.
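To narrow down where the gradient is lost, it can help to check which intermediate tensors still carry a `grad_fn`. The following is self-contained and is not the posted code: a hypothetical `mean_by_group` helper computes per-group means with `index_add` and `bincount`, and a hand-written `soft_ids` tensor stands in for `group_tags`. It shows that a group-by mean is differentiable with respect to the rows being averaged, but not with respect to the group ids once they are cast to integers.

```python
import torch


def mean_by_group(samples: torch.Tensor, group_ids: torch.Tensor, num_groups: int) -> torch.Tensor:
    """Mean of the rows of `samples` for each integer group id -> (num_groups, dim)."""
    sums = samples.new_zeros((num_groups, samples.shape[1]))
    sums = sums.index_add(0, group_ids, samples)              # per-group row sums
    counts = torch.bincount(group_ids, minlength=num_groups).clamp(min=1)
    return sums / counts.unsqueeze(1).to(samples.dtype)


samples = torch.arange(1, 9, dtype=torch.float32).unsqueeze(1).repeat(1, 4).requires_grad_()
soft_ids = torch.tensor([1., 0., 2., 2., 2., 0., 0., 0.], requires_grad=True)  # stands in for group_tags
group_ids = soft_ids.long()  # integer cast: the result is not tracked by autograd

means = mean_by_group(samples, group_ids, num_groups=3)
means.sum().backward()

print(group_ids.requires_grad)   # False
print(samples.grad is not None)  # True
print(soft_ids.grad)             # None
```

In the posted `mean_by_label`, the labels also go through `.long()` before being used as `scatter_add_` indices, so by the same reasoning no gradient from `group_embs` can flow back to `batch_tags`, which would be consistent with the `No gradient: fc.0.weight` message.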