Hi, I’m trying to implement the FedProx loss from this paper. My idea is to take the BCE loss and add the proximal term to it, like this:
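import torch
import torch.nn.functional as F

def fedprox_loss(output, target, new_weights, original_weights):
    # BCE expects probabilities in [0, 1]; the functional form computes the loss directly
    rss = F.binary_cross_entropy(output, target)
    # squared L2 distance between the current and the original weights
    omega = torch.sum(torch.pow(new_weights - original_weights, 2))
    return rss + omega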
But I don’t know if this works properly, and I’m not sure whether I need to implement the backward method myself or how to do it.
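For reference, a minimal sketch of the same idea applied to all parameters of a model, assuming the proximal term has the form (mu / 2) * ||w - w_global||^2. The names `proximal_term`, `mu`, and `global_params`, and the tiny model, are hypothetical and only there for illustration. Since everything is built from differentiable torch operations, autograd should be able to derive the gradients on its own, so a custom backward method should not be needed here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def proximal_term(model, global_params, mu=0.01):
    # (mu / 2) * squared L2 distance between the current parameters
    # and a frozen snapshot of the global parameters (assumed form)
    sq_dist = sum(torch.sum((p - g) ** 2)
                  for p, g in zip(model.parameters(), global_params))
    return 0.5 * mu * sq_dist

torch.manual_seed(0)
# Hypothetical toy model; Sigmoid keeps outputs in [0, 1] as BCE requires
model = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
# Detached copies, so no gradient flows into the global snapshot
global_params = [p.detach().clone() for p in model.parameters()]

x = torch.rand(8, 4)
y = torch.randint(0, 2, (8, 1)).float()

loss = F.binary_cross_entropy(model(x), y) + proximal_term(model, global_params, mu=0.1)
loss.backward()  # autograd computes gradients for both terms
```

After `loss.backward()`, every parameter's `.grad` is populated, and an optimizer step can follow as usual.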