Maintaining gradient computation graph through input subset

Hello,

I’m not sure how to best formulate this issue - I have a model that does something like this:

def forward(self, xyz):  # xyz is (B, 3, N)
    # a bunch of stuff that extracts features into l0_points
    # ...
    x = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))  # x: (B, 128, N), N = number of xyz points

Usually this would end in a semantic segmentation, but I’m not interested in that - I’m actually interested in picking a unique subset of xyz with exactly num_points points (a predetermined number) - so the output would be shaped like (B, 3, num_points).

I’m interested in using the learned extracted features, but I’m not sure how I can pick the subset of the points without disconnecting the gradient. One example of something I tried was:

# self.weights = nn.Conv1d(128, 1, 1)
scores = self.weights(x)  # shape: (B, 1, N)
scores = torch.softmax(weights, dim=-1)

_, top_indices = torch.topk(weights.squeeze(1), self.num_points, dim=-1)  # num_points is the size of the subset

Using the indices I can then pick the subset, but this obviously disconnects the computational graph… Is there a way to sort of “force” the learned features to translate into a unique, fixed-size subset of the input points?
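
To make the disconnect concrete, here is a standalone sketch (the random features tensor, the score_layer name and all shapes are just placeholders, not my real model):

import torch
import torch.nn as nn
import torch.nn.functional as F

B, N, num_points = 2, 100, 30
xyz = torch.randn(B, 3, N, requires_grad=True)
features = torch.randn(B, 128, N)   # placeholder for the real extracted features
score_layer = nn.Conv1d(128, 1, 1)  # stands in for self.weights

scores = F.softmax(score_layer(features).squeeze(1), dim=-1)       # (B, N)
_, idx = torch.topk(scores, num_points, dim=-1)                    # (B, num_points), int64

subset = torch.gather(xyz, 2, idx.unsqueeze(1).expand(-1, 3, -1))  # (B, 3, num_points)
subset.sum().backward()

print(xyz.grad.abs().sum() > 0)  # tensor(True) - gradient reaches xyz through gather
print(score_layer.weight.grad)   # None - nothing flows back into the scoring layer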

I hope I make sense. In the end, I expect to pass model(xyz) and get back a unique, hard subset of xyz shaped like (B, 3, num_points).

Thank you for any tips 🙂

This isn’t the case if you select from a differentiable output:

import torch
import torch.nn as nn

x = torch.randn(1, 1)
lin = nn.Linear(1, 10)

out = lin(x)

out_selected, indices = torch.topk(out, k=3, dim=1)
print(out_selected.shape)
# torch.Size([1, 3])
print(out_selected.grad_fn)
# <TopkBackward0 object at 0x7f4ce3ec5210>
print(indices)
# tensor([[1, 7, 0]])

out_selected.mean().backward()
print(lin.weight.grad)
# tensor([[-0.0697],
#         [-0.0697],
#         [ 0.0000],
#         [ 0.0000],
#         [ 0.0000],
#         [ 0.0000],
#         [ 0.0000],
#         [-0.0697],
#         [ 0.0000],
#         [ 0.0000]])
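
Note that only the rows picked by topk (rows 0, 1, and 7 here, matching the printed indices) receive a non-zero gradient, so the gradient does flow back through the selected values.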

In your example code you are also calling torch.topk(weights...), while in your initial example you want to process the xyz input to the model, not the trainable weights, so I’m unsure what exactly your use case is.

Hey ptrblck!
Thanks for the reply.

I meant using the indices, rather than out_selected as in your reply.
I’m not sure if my goal here is actually feasible; perhaps you would know a little better than I do.

I’ll try explaining my expected input/output:

Input: 3D points, let’s call these xyz
Output: exactly n 3D points that form a hard & unique subset of xyz, let’s call these out
This is basically what I want the model to learn.

My model extracts features from the xyz points through various techniques. For clarity, let’s call these features; they are shaped (B, 128, N).

I want to use features in combination with the input xyz to produce out.

Edit:
When I mentioned the “trick” that fails with the weights layer, I attempted to do something like this:

scores = self.weights(features)  # self.weights = nn.Conv1d(128, 1, 1)
scores = scores.squeeze(1)

_, indices = torch.topk(F.softmax(scores, dim=1), 30, dim=-1, largest=True, sorted=True)
selected_points = torch.gather(xyz, 2, indices.unsqueeze(1).expand(-1, 3, -1))  # (B, 3, 30) - This is "out"

I’m still missing details of your use case since selected_points is still differentiable:

import torch
import torch.nn as nn
import torch.nn.functional as F

weights = nn.Conv1d(128, 1, 1) 
features = torch.randn(1, 128, 100)

scores = weights(features)  # self.weights = nn.Conv1d(128, 1, 1)
scores = scores.squeeze(1)

_, indices = torch.topk(F.softmax(scores, dim=1), 30, dim=-1, largest=True, sorted=True)

xyz = torch.randn(1, 3, 100, requires_grad=True)
selected_points = torch.gather(xyz, 2, indices.unsqueeze(1).expand(-1, 3, -1))  # (B, 3, 30) - This is "out"

print(selected_points.shape)
# torch.Size([1, 3, 30])
print(selected_points.grad_fn)
# <GatherBackward0 object at 0x7ff13ddf9d80>

selected_points.mean().backward()

print(xyz.grad)
# tensor([[[0.0000, 0.0111, 0.0111, 0.0000, 0.0000, 0.0000, 0.0111, 0.0111,
#           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0111, 0.0111, 0.0000,
#           ...

I also don’t know whether you want to learn the indices themselves, which is not directly possible, since integer values would have a gradient of zero everywhere and Inf/NaN at the rounding points.
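
You can see this directly - the indices returned by topk are a plain integer tensor with no grad_fn, so there is nothing autograd could backpropagate through (reusing the shapes from the snippet above):

import torch
import torch.nn as nn
import torch.nn.functional as F

weights = nn.Conv1d(128, 1, 1)
features = torch.randn(1, 128, 100)

scores = F.softmax(weights(features).squeeze(1), dim=1)
_, indices = torch.topk(scores, 30, dim=-1)

print(indices.dtype)          # torch.int64 - integer indices cannot require gradients
print(indices.requires_grad)  # False
print(indices.grad_fn)        # None - the indices themselves are detached from the graph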