My framework consists of two parts: 1) a selection mechanism and 2) a predictor. The input data is a sequence, but not all items in the sequence are useful, so I use only the top-K items in the sequence to make the prediction.

However, I can't compute gradients for the selection model, because selecting items by their top-K indices blocks the gradient flow back to it (see the second part of the code).

Here is a toy example:

```python
import torch
import torch.nn as nn

seq = torch.randn(10, 15, 8)  # batch * seq_length * feature
sel_model = nn.Linear(8, 1)  # the selection mechanism
scores = sel_model(seq).squeeze()  # (10, 15): one score per item
# using the top-K indices causes the autograd problem
val, ind = torch.topk(scores, 3, dim=1, largest=True)
ind_exp = ind.unsqueeze(-1).expand(10, 3, 8)
selected_seq = torch.gather(seq, 1, ind_exp)  # (10, 3, 8): the selected items
feature = selected_seq.mean(-1)  # (10, 3)
pred_model = nn.Linear(3, 1)
pred = pred_model(feature).squeeze()
loss = pred.sum()  # assume there is a loss
loss.backward()
print('sel_model grad:', sel_model.weight.grad)  # grad is None
print('select scores grad:', scores.grad)  # grad is None
print('pred_model grad:', pred_model.weight.grad)
```

So how can I let the gradients flow back to the selection model?

Well, the core of the problem is that you never use scores or val in the subsequent computation. If everything after the selection is computed from the input sequence alone, you cannot expect gradients to flow back to the selection model, because they would have to pass through dloss/dval and dloss/dscores = dval/dscores · dloss/dval. (By the way, scores is a non-leaf tensor, so you need to call scores.retain_grad() before backward() to keep its gradient.)
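To illustrate the point, here is a minimal sketch based on the toy code from the question. It assumes one possible fix (not necessarily the only one): each selected item is weighted by its top-k score val before pooling, so that val, and therefore scores, actually enters the loss. With scores.retain_grad() in place, scores.grad is populated and is non-zero only at the selected positions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq = torch.randn(10, 15, 8)              # batch * seq_length * feature
sel_model = nn.Linear(8, 1)               # the selection mechanism
scores = sel_model(seq).squeeze(-1)       # (10, 15)
scores.retain_grad()                      # keep .grad on this non-leaf tensor

val, ind = torch.topk(scores, 3, dim=1)   # val is differentiable w.r.t. scores
ind_exp = ind.unsqueeze(-1).expand(-1, -1, seq.size(-1))
selected_seq = torch.gather(seq, 1, ind_exp)   # (10, 3, 8), no path to scores

# Assumption: weight each selected item by its score so that val enters the loss
weighted = selected_seq * val.unsqueeze(-1)    # (10, 3, 8)
feature = weighted.mean(-1)                    # (10, 3)
pred_model = nn.Linear(3, 1)
pred = pred_model(feature).squeeze(-1)
loss = pred.sum()                              # assume there is a loss
loss.backward()

print(sel_model.weight.grad is not None)  # True
print(scores.grad.shape)                  # torch.Size([10, 15])
```

Only the top-3 entries per row of scores.grad receive a gradient; the other positions stay at zero because they never influence the loss.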

Now, if you consider only the indices as the output of your selection model, you could look at reinforcement-learning techniques for learning discrete decisions.
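A minimal sketch of that idea, using the REINFORCE (score-function) estimator. As a simplification, it swaps top-K for independent Bernoulli keep/drop decisions per item, and the regression target is a hypothetical stand-in for the real task. The sampled mask carries no gradient; the selection model is trained only through the log-probability of the sampled mask, weighted by the baseline-subtracted negative loss.

```python
import torch
import torch.nn as nn
from torch.distributions import Bernoulli

torch.manual_seed(0)
seq = torch.randn(10, 15, 8)   # batch * seq_length * feature
target = torch.randn(10)       # hypothetical regression target
sel_model = nn.Linear(8, 1)    # selection policy: one keep-logit per item
pred_model = nn.Linear(8, 1)   # predictor on the mean of the kept items

logits = sel_model(seq).squeeze(-1)   # (10, 15)
policy = Bernoulli(logits=logits)
mask = policy.sample()                # (10, 15) of 0./1., sampled without gradient

# Mean over kept items; clamp avoids dividing by zero if nothing is kept
kept = mask.unsqueeze(-1)                                   # (10, 15, 1)
pooled = (seq * kept).sum(1) / kept.sum(1).clamp(min=1)     # (10, 8)
pred = pred_model(pooled).squeeze(-1)                       # (10,)
task_loss = (pred - target) ** 2                            # per-sample loss, (10,)

# REINFORCE: reward = -loss, detached so the policy term does not update the predictor
reward = -task_loss.detach()
advantage = reward - reward.mean()        # mean-reward baseline to reduce variance
log_prob = policy.log_prob(mask).sum(1)   # log-probability of each sampled mask, (10,)
policy_loss = -(advantage * log_prob).mean()

(task_loss.mean() + policy_loss).backward()
print(sel_model.weight.grad is not None)   # True
print(pred_model.weight.grad is not None)  # True
```

Sampling exactly K items would need a distribution over subsets (for example, sampling without replacement), which is left out here to keep the sketch short.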

However, I'm still wondering what the correct way to use the top-k values is, since it seems to me (though I'm not sure) that top-k breaks the chain rule. Could you provide some examples?

Below is a modified version of the previous code that uses the top-k values.

```python
import torch
import torch.nn as nn

data = torch.randn(10, 15)  # batch * features
select_model = nn.Linear(15, 15)  # each feature has a score
scores = select_model(data)
# use the top-3 features and mask the rest
val, ind = torch.topk(scores, 3, dim=1, largest=True)
masked_scores = torch.zeros_like(scores)
masked_scores.scatter_(1, ind, val)  # keep the top-3 scores, zeros elsewhere
masked_data = data * masked_scores  # mask out the non-selected features
# I can use retain_grad to keep grad, but which one should I use?
# scores.retain_grad()
# masked_scores.retain_grad()
# masked_data.retain_grad()
# do the prediction task using the masked data
pred_model = nn.Linear(15, 1)
pred = pred_model(masked_data).squeeze()
loss = pred.sum()  # assume there is a loss
loss.backward()
print('select_model grad:', select_model.weight.grad.size())
print('select scores grad:', scores.grad)  # grad is None
print('masked scores grad:', masked_scores.grad)  # grad is None
print('pred_model grad:', pred_model.weight.grad.size())
```

topk does propagate gradients from its output values back to its input: its backward pass is essentially a scatter of the incoming gradient into the selected indices, much like the scatter_ you wrote yourself. So the top-k part you show should work now.
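A quick way to check this, reusing the setup from the second snippet (the fixed seed and the comparison at the end are additions for illustration): call retain_grad() on scores and val, then compare scores.grad with val.grad scattered into the top-k positions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
data = torch.randn(10, 15)                 # batch * features
select_model = nn.Linear(15, 15)           # each feature has a score
scores = select_model(data)
scores.retain_grad()                       # keep .grad on these non-leaf tensors
val, ind = torch.topk(scores, 3, dim=1)
val.retain_grad()

masked_scores = torch.zeros_like(scores).scatter(1, ind, val)
pred_model = nn.Linear(15, 1)
pred = pred_model(data * masked_scores).squeeze(-1)
pred.sum().backward()

# topk's backward places val.grad at the selected positions; all others get 0
expected = torch.zeros_like(scores).scatter(1, ind, val.grad)
print(torch.allclose(scores.grad, expected))  # True
```

Any of the intermediate tensors can be given retain_grad() for inspection; it only controls whether .grad is kept, not whether gradients flow.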