Simple implementation of Chamfer distance in PyTorch for 2D point cloud data

I am working on generative modelling of 2D point clouds, and I want to use the Chamfer distance as the loss function. PyTorch itself has no built-in function for it, so I would need to write my own. I found a simple, similar implementation in NumPy, but I want one that runs on the GPU and supports backpropagation, i.e. a loss function built on nn.Module. I am having problems with the implementation, though. There are some existing implementations, but they are not compatible with my data format: my point clouds are PyG Data objects, so a batched point cloud X has shape (B*N, 2) rather than (batch, N, 2). So importing those will not work either. Is there any simple implementation in PyTorch that I could use for the loss function? I also went through @pclucas's question, which was similar, but it didn't really work out for me. Could anyone please help me with this? Thank you.
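For reference, the Chamfer distance can be written with plain tensor operations, so autograd provides the backward pass and it runs on the GPU whenever the inputs are CUDA tensors. A minimal sketch, assuming every cloud in the batch has the same number of points and using squared Euclidean distances averaged over points (a common convention; other implementations use unsquared distances or sums). `chamfer_distance` is a hypothetical helper name, not a PyTorch or PyG function.

```python
import torch


def chamfer_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Symmetric Chamfer distance between batched clouds x: (B, N, D) and y: (B, M, D).

    Uses squared Euclidean distances, averages over points, then over the batch.
    """
    # Pairwise squared distances, shape (B, N, M); plain tensor ops, so autograd works.
    d = ((x.unsqueeze(2) - y.unsqueeze(1)) ** 2).sum(dim=-1)
    # Each point in x to its nearest point in y, and each point in y to its nearest in x.
    x_to_y = d.min(dim=2).values.mean(dim=1)  # (B,)
    y_to_x = d.min(dim=1).values.mean(dim=1)  # (B,)
    return (x_to_y + y_to_x).mean()


if __name__ == "__main__":
    B, N = 4, 128
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # PyG-style flat layout (B*N, 2), assumed to hold B clouds of exactly N points each.
    target = torch.rand(B * N, 2, device=device)
    pred = torch.rand(B * N, 2, device=device, requires_grad=True)
    loss = chamfer_distance(target.view(B, N, 2), pred.view(B, N, 2))
    loss.backward()  # gradients flow back to pred
    print(loss.item(), pred.grad.shape)
```

Computing the full distance matrix costs O(B·N·M) memory, which is fine for a few thousand points per cloud.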

I’m not sure I understand why the shape would be a problem. Could you explain the issue a bit more?

Generally, you would have to either use PyTorch operations to calculate the loss in a differentiable way (I don't know if PyG already provides it) or write a custom autograd.Function that implements the backward pass.
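Following the first option (plain PyTorch operations), the loss can also be computed directly on PyG's flat layout: the per-point batch vector is used to forbid matches between points of different clouds, which also allows clouds of different sizes. A sketch under these assumptions: `x_batch`/`y_batch` are per-point cloud indices like PyG's `Batch.batch`, every cloud index appears in both inputs, and squared Euclidean distances are used. `chamfer_distance_flat` is a hypothetical name. It builds the full (Nx, Ny) distance matrix, so memory grows quadratically with the total number of points in the batch.

```python
import torch


def chamfer_distance_flat(x, x_batch, y, y_batch):
    """Symmetric Chamfer distance on PyG-style flat tensors.

    x: (Nx, D) points of all clouds concatenated, x_batch: (Nx,) cloud index per point.
    y: (Ny, D), y_batch: (Ny,). Clouds may have different sizes.
    Returns the per-cloud Chamfer distance averaged over clouds.
    """
    # Squared distances between every pair of points in the whole batch: (Nx, Ny).
    d = ((x.unsqueeze(1) - y.unsqueeze(0)) ** 2).sum(dim=-1)
    # Points from different clouds must never be matched, so mask those pairs to +inf.
    same_cloud = x_batch.unsqueeze(1) == y_batch.unsqueeze(0)
    d = d.masked_fill(~same_cloud, float("inf"))
    # Nearest-neighbour squared distance for every point, in both directions.
    x_nn = d.min(dim=1).values  # (Nx,)
    y_nn = d.min(dim=0).values  # (Ny,)
    # Per-cloud means via index_add (sums) and bincount (point counts).
    num_clouds = int(max(x_batch.max().item(), y_batch.max().item())) + 1
    x_sum = x_nn.new_zeros(num_clouds).index_add(0, x_batch, x_nn)
    y_sum = y_nn.new_zeros(num_clouds).index_add(0, y_batch, y_nn)
    x_cnt = torch.bincount(x_batch, minlength=num_clouds).to(x_nn.dtype)
    y_cnt = torch.bincount(y_batch, minlength=num_clouds).to(y_nn.dtype)
    return (x_sum / x_cnt + y_sum / y_cnt).mean()


if __name__ == "__main__":
    # Two clouds of different sizes, concatenated PyG-style.
    x = torch.rand(7, 2, requires_grad=True)
    x_batch = torch.tensor([0, 0, 0, 1, 1, 1, 1])
    y = torch.rand(5, 2)
    y_batch = torch.tensor([0, 0, 1, 1, 1])
    loss = chamfer_distance_flat(x, x_batch, y, y_batch)
    loss.backward()
```

With a PyG batch, the call would be something like `chamfer_distance_flat(pred, data.batch, target, data.batch)`, assuming the prediction is produced per node and therefore shares the batch vector of the target.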

Yeah, I was going to implement my own custom function, but then I found a library called chamferdist. I used that library, wrapped it into a torch.nn.Module, and used that as the loss function in my case.

I used something like this:

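import torch
import torch.nn as nn
from chamferdist import ChamferDistance


class PointCloudLoss(nn.Module):
    def __init__(self, npoints):
        super().__init__()
        self.cd = ChamferDistance()
        self.npoints = npoints  # points per cloud; every cloud must have exactly this many

    def earth_mover_distance(self, y_true, y_pred):
        # Squared difference of cumulative sums along the last (coordinate) dimension.
        # This is a cheap surrogate, not the exact EMD between the two point sets.
        return torch.mean(torch.square(
            torch.cumsum(y_true, dim=-1) - torch.cumsum(y_pred, dim=-1)), dim=-1).mean()

    def forward(self, y_true, y_pred):
        # PyG stacks the clouds as (B*N, 2); chamferdist expects (B, N, 2).
        if y_true.ndim != 3 and self.npoints is not None:
            batch = y_true.shape[0] // self.npoints
            y_true = y_true.view(batch, self.npoints, 2)

        if y_pred.ndim != 3 and self.npoints is not None:
            # Use y_pred's own size: y_true may already be reshaped to (B, N, 2) here.
            batch = y_pred.shape[0] // self.npoints
            y_pred = y_pred.view(batch, self.npoints, 2)

        return self.cd(y_true, y_pred, bidirectional=True) + self.earth_mover_distance(y_true, y_pred)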
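The `earth_mover_distance` method in the snippet compares cumulative sums along the last dimension, i.e. across the two coordinates of each point, so it is a cheap surrogate rather than the Earth Mover's Distance between the two point sets. For two clouds with the same number of points and uniform weights, the exact EMD equals the mean distance under an optimal one-to-one matching of the points. A minimal sketch, assuming SciPy is installed and the clouds are equal-size; `exact_emd` is a hypothetical name. The Hungarian solver (`scipy.optimize.linear_sum_assignment`) runs on the CPU in roughly cubic time, so this is only practical for small clouds. The matching itself is not differentiable, but gradients still flow through the distances of the matched pairs.

```python
import torch
from scipy.optimize import linear_sum_assignment


def exact_emd(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Exact EMD between equal-size clouds x, y of shape (B, N, D).

    Uniform point weights and Euclidean ground distance; returns the batch mean.
    """
    if x.shape != y.shape:
        raise ValueError("this sketch assumes clouds with the same number of points")
    # Pairwise Euclidean distances, shape (B, N, N).
    cost = torch.linalg.vector_norm(x.unsqueeze(2) - y.unsqueeze(1), dim=-1)
    per_cloud = []
    for b in range(x.shape[0]):
        # Optimal one-to-one matching (Hungarian algorithm) on a detached CPU copy.
        rows, cols = linear_sum_assignment(cost[b].detach().cpu().numpy())
        rows = torch.as_tensor(rows, device=x.device)
        cols = torch.as_tensor(cols, device=x.device)
        # The matching is treated as a constant; gradients flow through the matched distances.
        per_cloud.append(cost[b, rows, cols].mean())
    return torch.stack(per_cloud).mean()
```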

I used this reshaping because this library takes input of shape B x N x c, where B is the batch size, N is the number of points in each point cloud, and c is the number of features per point (here 2, since I only want the two coordinates as features). PyG batches graphs by concatenating them into a single (B*N, c) tensor, so during the loss computation only that tensor has to be reshaped back to (B, N, c); this assumes every cloud has exactly npoints points. And it kind of worked, though I didn't achieve my generative goal with it, haha. But yeah.
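The reshape described above relies on two properties of PyG's batching: rows of the flat (B*N, c) tensor are grouped cloud by cloud in order, and every cloud has the same number of points. If the second property fails but the total still divides evenly, `.view` succeeds and silently mixes points from neighbouring clouds. A small hypothetical helper, `unflatten_batch`, that checks both properties using the per-point `batch` vector before reshaping:

```python
import torch


def unflatten_batch(flat: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
    """Reshape a PyG-style (B*N, c) tensor to (B, N, c), checking that this is valid.

    batch: (B*N,) cloud index per row, as in PyG's Batch.batch.
    """
    # Rows must be grouped by cloud in increasing order, as PyG concatenates them.
    if not bool((batch[1:] >= batch[:-1]).all()):
        raise ValueError("batch vector is not sorted; rows are not grouped by cloud")
    counts = torch.bincount(batch)
    n = int(counts[0])
    # Every cloud must contribute exactly the same number of points.
    if not bool((counts == n).all()):
        raise ValueError(f"clouds have different sizes: {counts.tolist()}")
    return flat.reshape(len(counts), n, flat.shape[-1])


if __name__ == "__main__":
    clouds = [torch.rand(4, 2) for _ in range(3)]  # three clouds of 4 points each
    flat = torch.cat(clouds)  # (12, 2), like PyG's concatenated layout
    batch = torch.arange(3).repeat_interleave(4)  # [0, 0, 0, 0, 1, ..., 2]
    dense = unflatten_batch(flat, batch)
    print(dense.shape)  # torch.Size([3, 4, 2])
    assert torch.equal(dense[1], clouds[1])
```

Raising an error here is deliberate: for clouds of varying size, padding (or a masked loss that works on the flat layout directly) is needed instead of a plain reshape.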