Sample from tensor

Hi everyone 👋

I need to create a layer that samples points from a tensor while keeping the computation graph intact, so that gradients backpropagate properly.
More precisely, I have objects from the ModelNet dataset represented as point clouds, and I need to draw a subset of points from each cloud.

The layer I came up with is:

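import torch
import torch.nn as nn

class Sampler(nn.Module):
    def __init__(self, N: int):
        super(Sampler, self).__init__()
        # Number of points to extract
        self.N = N

    def forward(self, x):
        r'''Source: reply by @rasbt.'''
        # Pick N distinct random point indices along the point dimension
        rand_columns = torch.randperm(x.shape[1])[:self.N]
        # Copy x, keep only the sampled points, and mark the result as requiring grad
        out = x.clone().detach()[:, rand_columns, :].requires_grad_(True)
        return out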

where x is a tensor of shape [batch_size, n_points, n_features] and N < n_points,
so the out tensor has shape [batch_size, N, n_features].
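A minimal sketch for checking the shapes above, assuming random data with 4 clouds of 1024 points and 3 features each. One thing worth knowing: `detach()` returns a tensor that is cut off from the autograd graph, so the output of `Sampler` is a new leaf tensor and no gradient reaches `x`. The `IndexSampler` class below is a hypothetical variant that indexes `x` directly, which autograd can differentiate through.

```python
import torch
import torch.nn as nn


class IndexSampler(nn.Module):
    """Hypothetical variant of Sampler that keeps x in the autograd graph."""

    def __init__(self, N: int):
        super().__init__()
        self.N = N  # number of points to keep per cloud

    def forward(self, x):
        # Same random choice of points as Sampler, created on x's device
        rand_columns = torch.randperm(x.shape[1], device=x.device)[: self.N]
        # Plain advanced indexing is differentiable: no clone()/detach() here
        return x[:, rand_columns, :]


batch_size, n_points, n_features, N = 4, 1024, 3, 256
x = torch.randn(batch_size, n_points, n_features, requires_grad=True)

out = IndexSampler(N)(x)
print(out.shape)  # torch.Size([4, 256, 3])

# Backpropagate a scalar: every sampled entry of x gets gradient 1,
# every entry that was not sampled gets gradient 0.
out.sum().backward()
print(x.grad.sum().item())  # 3072.0 (= 4 * 256 * 3)

# For comparison, the clone().detach() pattern yields a fresh leaf tensor
# with no grad_fn, i.e. nothing links it back to x.
detached = x.clone().detach()[:, :N, :].requires_grad_(True)
print(detached.grad_fn is None, detached.is_leaf)  # True True
```

Because `torch.randperm` returns distinct indices, each sampled entry appears once in `out`, which is why the gradient sum equals the number of sampled entries.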

Everything seems to work, but I keep getting this UserWarning about copying tensors:
UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor)
even though I already call clone().detach() on the x tensor. What am I missing here?
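Since `Sampler` never calls `torch.tensor`, the warning most likely comes from another part of the code (for example the dataset or a transform). A minimal sketch, relying only on Python's standard `warnings` module, of how to escalate the warning to an exception so that its traceback points at the line that triggers it; `first_user_warning` is a hypothetical helper name.

```python
import warnings

import torch


def first_user_warning(fn):
    """Hypothetical helper: run fn() with UserWarning escalated to an error.

    Returns the warning text if one was raised, otherwise None.
    """
    with warnings.catch_warnings():
        # Inside this block every UserWarning is raised as an exception,
        # so a traceback would point at the offending line.
        warnings.simplefilter("error", UserWarning)
        try:
            fn()
        except UserWarning as exc:
            return str(exc)
    return None


src = torch.randn(2, 10, 3)

# Copy-constructing with torch.tensor(...) is the pattern named in the warning
bad = first_user_warning(lambda: torch.tensor(src))
# clone().detach() makes the same copy without a warning
good = first_user_warning(lambda: src.clone().detach())

print(bad is not None, good is None)  # True True
```

In a full training script, the same effect can be had by calling `warnings.simplefilter("error", UserWarning)` once at startup, or by running Python with `-W error::UserWarning`; note that this also turns unrelated UserWarnings into errors.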

The forward function is adapted from a reply by @rasbt.

I'm also not sure why you'd get the warning; what you are doing seems right. Maybe it's just a false positive, I don't know. What you could try is

out = torch.tensor(x.clone().detach()[:, rand_columns, :], requires_grad=True)

instead of

out = x.clone().detach()[:, rand_columns, :].requires_grad_(True)
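To compare the two lines empirically, a small sketch on random data (the shapes are arbitrary assumptions). It records warnings around each line and checks whether both produce the same values. Wrapping an existing tensor in `torch.tensor(...)` is the pattern the warning message names, so the first variant can be expected to emit it. Both variants return a new leaf tensor with no `grad_fn`.

```python
import warnings

import torch

x = torch.randn(2, 10, 3, requires_grad=True)
rand_columns = torch.randperm(x.shape[1])[:4]

with warnings.catch_warnings(record=True) as caught_a:
    warnings.simplefilter("always")  # record every warning, even repeats
    # Variant A: torch.tensor(...) around an existing tensor
    out_a = torch.tensor(x.clone().detach()[:, rand_columns, :], requires_grad=True)

with warnings.catch_warnings(record=True) as caught_b:
    warnings.simplefilter("always")
    # Variant B: the clone().detach().requires_grad_(True) chain
    out_b = x.clone().detach()[:, rand_columns, :].requires_grad_(True)

print(len(caught_a) > 0, len(caught_b) > 0)  # True False
print(torch.equal(out_a, out_b))  # True
print(out_a.is_leaf, out_b.is_leaf)  # True True
```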