I have a technical question. Consider the following example, where I want to select only the top 10 values along the last dimension. It works as intended, but I’m worried I might break the computation graph, since I’m creating a new tensor in line 2. Can anyone confirm whether this is the right approach, or how I could verify it?
topk, indices = torch.topk(support, 10)
support = torch.zeros([B, N, N], device=self.device).scatter_(2, indices, topk)
You can use a toy example like this to see which gradients are propagated back to support:
support = torch.rand(20, requires_grad=True)
topk, indices = torch.topk(support, 10)
out = torch.zeros(20).scatter_(0, indices, topk)
out.sum().backward()
print(support.grad)
# That will give you a 1 for each entry that was selected by topk and a 0 elsewhere
The gradients will propagate as you expect here, no problem!
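For completeness, the same check can be run with batched shapes like the ones in your question (here B, N, and k are made-up sizes, since the originals aren't given):

```python
import torch

B, N, k = 2, 5, 3  # hypothetical batch size, matrix size, and top-k

support = torch.rand(B, N, N, requires_grad=True)
topk, indices = torch.topk(support, k, dim=2)
# scatter_ writes the selected values back into a fresh zero tensor;
# gradients flow through the src argument (topk) back to support
out = torch.zeros(B, N, N).scatter_(2, indices, topk)
out.sum().backward()

# Each row of support.grad has exactly k ones (the selected entries)
# and zeros everywhere else
print(support.grad.sum().item())  # B * N * k
```

Creating the zeros tensor itself does not break the graph: it simply receives no gradient, while the scattered values keep their connection to support.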