Computation graph integrity

I have a technical question. Consider the following example, where I want to keep only the top 10 values along the last dimension. The example works as intended, but I am worried that I might break the computation graph, since I'm creating a new tensor in the second line. Can anyone confirm whether this is the right approach, or suggest how I could verify it?

 topk, indices = torch.topk(support, 10)
 support = torch.zeros([B, N, N], device=self.device).scatter_(2, indices, topk)

Hi,

You can use a toy example like this to see which gradients are propagated back to support:

support = torch.rand(20, requires_grad=True)

topk, indices = torch.topk(support, 10)
out = torch.zeros(20).scatter_(0, indices, topk)

out.sum().backward()
print(support.grad)
# That will give you a 1 for each entry that was selected by the topk

The gradients will propagate as you expect here, no problem!
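
For completeness, here is a sketch of the same check run on the batched [B, N, N] case from your snippet (B and N are placeholder sizes here; your code takes them and the device from the model):

```python
import torch

B, N = 2, 16  # placeholder sizes for illustration
support = torch.rand(B, N, N, requires_grad=True)

topk, indices = torch.topk(support, 10)  # top 10 along the last dimension
out = torch.zeros(B, N, N).scatter_(2, indices, topk)

out.sum().backward()
# Each row along the last dimension has exactly 10 entries with gradient 1,
# one per selected value, and 0 everywhere else.
print(support.grad.sum(-1))  # every value is 10.0
```

The in-place scatter_ on the fresh zeros tensor is fine for autograd: the zeros tensor does not require grad, and gradients flow through the src argument (topk) back to support.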
