[HELP] FedProx loss implementation

Hi, I’m trying to implement the FedProx loss from this paper. My idea is to take the BCE loss and modify it like this:

import torch
import torch.nn as nn

def fedprox_loss(output, target, new_weights, original_weights):
    rss = nn.BCELoss()(output, target)  # nn.BCELoss is a module: instantiate it, then call it
    omega = torch.sum(torch.pow(new_weights - original_weights, 2))  # squared L2 distance
    return rss + omega

But I don’t know if this works properly, and I’m not sure whether I need to implement the backward method, or how I would do that.

If all your operations are on PyTorch tensors, you won’t have to implement the backward pass manually; Autograd will derive it for you.
Are you seeing any errors or unexpected behavior?

You forgot to multiply by mu: in the FedProx paper the proximal term is (mu / 2) * ||w - w_t||^2, so your omega should be scaled by mu / 2.
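Putting the replies together, here is a minimal sketch of the corrected loss. The default value of `mu` below is a made-up placeholder you would tune yourself, and both weight arguments are assumed to be flattened 1-D tensors:

```python
import torch
import torch.nn as nn

def fedprox_loss(output, target, new_weights, original_weights, mu=0.01):
    # Standard BCE term (nn.BCELoss must be instantiated before being called)
    bce = nn.BCELoss()(output, target)
    # FedProx proximal term: (mu / 2) * ||w - w_t||^2
    prox = (mu / 2) * torch.sum((new_weights - original_weights) ** 2)
    return bce + prox

# Autograd builds the backward pass automatically; no custom backward() needed.
output = torch.sigmoid(torch.randn(8, requires_grad=True))
target = torch.randint(0, 2, (8,)).float()
w_new = torch.randn(5, requires_grad=True)
w_old = w_new.detach() + 0.1

loss = fedprox_loss(output, target, w_new, w_old)
loss.backward()  # w_new.grad now holds mu * (w_new - w_old)
```

Since the BCE term does not depend on the weight tensors here, the gradient on `w_new` comes entirely from the proximal term, which is an easy way to sanity-check the scaling by `mu`.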