backward() takes 4 positional arguments but 5 were given

I need help with my .backward() call, as I don’t know how to fix the error it raises. I believe it is caused by the custom loss function, which takes inputs of shape (1001, 1, 3) (where 1001 is the batch size) and outputs a shape of (1) (as the loss function acts over the whole batch), but I can’t be sure. Is there a way to fix TypeError: backward() takes 4 positional arguments but 5 were given? My code is below:

```python
import torch

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# model takes input of shape (batch_size, 1, 3) and returns output of shape (batch_size, 1, 3)
# I would like the loss function to consider the elements of the batch, too
# loss_fn takes (x, y), each of shape (batch_size, 1, 3),
# and returns (d1, d2, idx1, idx2): float distance tensors d1, d2 and integer index tensors idx1, idx2

def train_loop(model, loss_fn, optimizer, epochs=1000, source=source_shape, target=target_shape):
    # source_shape and target_shape are tensors, each of shape (1001, 1, 3)
    for epoch in range(epochs):
        optimizer.zero_grad()  # clear gradients left over from the previous iteration
        d1, d2, idx1, idx2 = loss_fn(model(source), target)
        loss = torch.mean(d1) + torch.mean(d2)
        # backprop
        loss.backward()  # <- here is where the error is
        optimizer.step()  # update the model parameters
```
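A likely cause, assuming the library wraps its kernel in a custom torch.autograd.Function (its source is not shown here): PyTorch calls that Function's backward with ctx plus one gradient per output of forward. A forward returning four tensors (d1, d2, idx1, idx2) therefore needs a backward with five parameters; one declared as backward(ctx, g1, g2, g3) takes 4 positional arguments and fails with exactly this message. The class below is a hypothetical, brute-force stand-in (ToyChamfer is not the library's code, and it uses squared distances) that returns four outputs, marks the index tensors as non-differentiable, and accepts four incoming gradients.

```python
import torch


class ToyChamfer(torch.autograd.Function):
    """Hypothetical sketch (not the library's code): brute-force nearest-neighbour
    squared distances with a hand-written backward. x: (B, N, 3), y: (B, M, 3)."""

    @staticmethod
    def forward(ctx, x, y):
        diff = x.unsqueeze(2) - y.unsqueeze(1)   # (B, N, M, 3) pairwise differences
        sq = (diff ** 2).sum(-1)                 # (B, N, M) squared distances
        d1, idx1 = sq.min(dim=2)                 # nearest y for each x: (B, N)
        d2, idx2 = sq.min(dim=1)                 # nearest x for each y: (B, M)
        ctx.save_for_backward(x, y, idx1, idx2)
        ctx.mark_non_differentiable(idx1, idx2)  # integer outputs carry no gradient
        return d1, d2, idx1, idx2                # four outputs...

    @staticmethod
    def backward(ctx, grad_d1, grad_d2, grad_idx1, grad_idx2):
        # ...so backward receives ctx plus four gradients (five arguments in total);
        # the gradients for the index outputs are accepted and ignored
        x, y, idx1, idx2 = ctx.saved_tensors
        # d1[b, i] = |x[b, i] - y[b, idx1[b, i]]|^2
        index1 = idx1.unsqueeze(-1).expand(-1, -1, 3)      # (B, N, 3)
        nn_y = torch.gather(y, 1, index1)                  # nearest y for each x
        g1 = 2 * (x - nn_y) * grad_d1.unsqueeze(-1)
        grad_x = g1.clone()
        grad_y = torch.zeros_like(y).scatter_add_(1, index1, -g1)
        # d2[b, j] = |y[b, j] - x[b, idx2[b, j]]|^2
        index2 = idx2.unsqueeze(-1).expand(-1, -1, 3)      # (B, M, 3)
        nn_x = torch.gather(x, 1, index2)                  # nearest x for each y
        g2 = 2 * (y - nn_x) * grad_d2.unsqueeze(-1)
        grad_y = grad_y + g2
        grad_x = grad_x.scatter_add(1, index2, -g2)
        return grad_x, grad_y  # one gradient per forward input (x, y)
```

If the real library's backward has fewer gradient parameters than its forward has outputs, adding the missing ones (they can be ignored for the index outputs) would be one way to make the signatures agree; the change would go in the library's Function, not in the training loop.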

The loss function is the Chamfer distance, taken from a chamfer_distance library on GitHub. It is computed as follows: for two point clouds Sx and Sy (say, each of shape (n, 1, 3)), for each point in Sx find the minimal L2 distance to any point in Sy, and sum these over the points in Sx (this is how d1 is calculated in my code above). The same is then done for the points in Sy, and the two sums are added to get the final loss.
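For reference, here is a minimal pure-PyTorch sketch of the computation described above; it is not the library's implementation. It uses torch.cdist for plain (non-squared) L2 distances, sums over points as described, and assumes, as many Chamfer implementations do, that inputs have shape (batch, points, 3). The function name chamfer_distance here is hypothetical. Because it is built only from differentiable PyTorch ops, autograd derives the backward pass automatically.

```python
import torch


def chamfer_distance(sx: torch.Tensor, sy: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: Chamfer distance between point clouds sx (B, N, 3)
    and sy (B, M, 3), using plain L2 distances summed over points.
    Returns one value per batch item, shape (B,)."""
    dists = torch.cdist(sx, sy)               # (B, N, M) pairwise Euclidean distances
    d1 = dists.min(dim=2).values.sum(dim=1)   # each point in sx to its nearest point in sy
    d2 = dists.min(dim=1).values.sum(dim=1)   # each point in sy to its nearest point in sx
    return d1 + d2
```

Under that (batch, points, 3) convention, a tensor of shape (1001, 1, 3) is a batch of 1001 clouds with one point each, so the loss reduces to twice the pointwise distance; comparing two clouds of 1001 points each would instead use source.view(1, 1001, 3).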

Could you explain the implementation of the custom loss function in more detail or post its definition, please?

I have just edited my post to address this.