Hi everyone
I need to create a layer that samples data from a tensor while keeping the computation graph intact, so that gradients backpropagate properly.
More precisely, I have objects from the ModelNet dataset represented as point clouds, and I need to draw a subset of points from each cloud.
The layer I came up with is:
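import torch
import torch.nn as nn


class Sampler(nn.Module):
    def __init__(self, N: int):
        super(Sampler, self).__init__()
        # Number of points to extract
        self.N = N

    def forward(self, x):
        r''' source: https://discuss.pytorch.org/t/take-random-sample-of-long-tensor-create-new-subset/36244/3
        author: @rasbt
        '''
        # Pick N distinct random point indices along dim 1 (n_points)
        rand_columns = torch.randperm(x.shape[1])[:self.N]
        # Copy x, cut the copy from the autograd graph, keep the sampled points,
        # then mark the result as a new leaf that requires grad
        out = x.clone().detach()[:, rand_columns, :].requires_grad_(True)
        return out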
Here x is a tensor of shape [batch_size, n_points, n_features] and N < n_points,
so the out tensor has shape [batch_size, N, n_features].
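If the goal stated at the top is for gradients to flow back into x (and whatever layer produced it), a minimal sketch under that assumption is to index x directly. Advanced indexing is differentiable, whereas x.clone().detach() returns a tensor with no autograd history, so the trailing .requires_grad_(True) only creates a new leaf and nothing upstream receives gradients. GraphSampler and the tensor sizes below are hypothetical, chosen only for illustration; creating the indices on x.device is a design choice, not something the original requires.

```python
import torch
import torch.nn as nn


class GraphSampler(nn.Module):
    """Hypothetical variant of Sampler that keeps the autograd graph."""

    def __init__(self, N: int):
        super().__init__()
        self.N = N  # number of points to extract

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # N distinct random point indices, created on the same device as x
        idx = torch.randperm(x.shape[1], device=x.device)[: self.N]
        # Plain advanced indexing is differentiable: no clone/detach needed
        return x[:, idx, :]


x = torch.randn(4, 1024, 3, requires_grad=True)  # [batch_size, n_points, n_features]
out = GraphSampler(N=256)(x)
print(out.shape)                # torch.Size([4, 256, 3])
print(out.grad_fn is not None)  # True: out is still attached to x

# For contrast: the clone().detach() pattern yields a fresh leaf with no history
detached = x.clone().detach()[:, :256, :].requires_grad_(True)
print(detached.grad_fn)         # None: gradients stop here and never reach x

out.sum().backward()
# Each sampled point receives gradient 1, every other point receives 0
print(int((x.grad.abs().sum(dim=(0, 2)) > 0).sum()))  # 256
```

Because torch.randperm never repeats an index, exactly N points end up with a nonzero gradient.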
Everything seems to work, but I keep getting a UserWarning about copy-constructing tensors:
UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor)
even though I already use clone().detach() on the x tensor. What am I missing here?
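The warning text names torch.tensor(sourceTensor) as the trigger, so it should be raised wherever torch.tensor(...) is called on an existing tensor, not by clone().detach() itself. The sketch below checks which calls emit it, assuming the warning goes through Python's standard warnings machinery (it is reported as a UserWarning, so it should). The helper emits_copy_warning is hypothetical, and it matches only the stable opening words of the message.

```python
import warnings

import torch


def emits_copy_warning(fn) -> bool:
    """Return True if calling fn() raises PyTorch's copy-construct UserWarning."""
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # record every warning, even repeats
        fn()
    return any("To copy construct from a tensor" in str(w.message) for w in caught)


x = torch.randn(2, 10, 3)

print(emits_copy_warning(lambda: torch.tensor(x)))               # True
print(emits_copy_warning(lambda: x.clone().detach()))            # False
print(emits_copy_warning(lambda: x.clone().detach()[:, :5, :]))  # False
```

To locate the offending call in a full training script, one option is to add warnings.filterwarnings("error", message="To copy construct from a tensor") near the top: the warning then becomes an exception whose traceback should point at the torch.tensor(...) line, which may live in a dataset, transform, or loss function rather than in Sampler.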
The forward function is taken from a reply by @rasbt (linked in the docstring).