I want to implement a custom loss function for a U-Net model trained on H&E images. This is what I have so far, though I am not sure whether I have made any reasoning mistakes. Both my predictions and my annotations have the shape [B, C, H, W], and the annotations are one-hot encoded, i.e. each pixel has a 1 in the channel of its class and 0 elsewhere.

```
class Tversky_Focal_Loss(nn.Module):
    """Weighted sum of a per-class Tversky loss and a per-channel focal loss.

    Both inputs to ``forward`` are expected as ``[B, C, H, W]`` tensors:
    raw logits for the predictions and one-hot encoded annotations.

    Args:
        device: Device the inputs (and the class weights) are moved to.
        weight: Optional per-class weights of length ``C`` applied to the
            focal term only. May be a tensor or a plain sequence of numbers.
        alpha: Penalty applied to false negatives in the Tversky index.
        beta: Penalty applied to false positives in the Tversky index.
        gamma: Focusing exponent of the focal term.
        epsilon: Smoothing / clamping constant used for numerical stability.
        focal_weight: Coefficient of the focal term in the final sum.
        tversky_weight: Coefficient of the Tversky term in the final sum.
    """

    def __init__(self, device, weight=None, alpha=0.85, beta=0.15, gamma=3.0,
                 epsilon=1e-7, focal_weight=0.3, tversky_weight=0.7):
        super().__init__()
        # Always define the attribute: forward() checks `self.weight is not
        # None`, which would raise AttributeError if it were never assigned.
        if weight is not None:
            weight = torch.as_tensor(weight, dtype=torch.float32, device=device)
        self.weight = weight
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.epsilon = epsilon
        self.device = device
        self.focal_weight = focal_weight
        self.tversky_weight = tversky_weight

    def forward(self, activations, annotations):
        """Return the scalar combined loss.

        Args:
            activations: Raw network outputs (logits), shape ``[B, C, H, W]``.
            annotations: One-hot encoded targets, shape ``[B, C, H, W]``.

        Returns:
            A 0-dim tensor:
            ``focal_weight * focal + tversky_weight * tversky``.
        """
        activations = activations.float().to(self.device)
        annotations = annotations.float().to(self.device)

        # Softmax across channels gives class probabilities; clamping keeps
        # log() finite for saturated predictions.
        probabilities = torch.softmax(activations, dim=1)  # [B, C, H, W]
        probabilities = torch.clamp(
            probabilities, min=self.epsilon, max=1 - self.epsilon
        )

        # --- Tversky term -------------------------------------------------
        # Soft confusion counts, summed over the spatial dims -> [B, C].
        true_pos = (probabilities * annotations).sum(dim=(2, 3))
        false_neg = ((1 - probabilities) * annotations).sum(dim=(2, 3))
        false_pos = (probabilities * (1 - annotations)).sum(dim=(2, 3))

        # NOTE: in this implementation alpha weights false negatives and beta
        # weights false positives.
        tversky_index = (true_pos + self.epsilon) / (
            true_pos
            + self.alpha * false_neg
            + self.beta * false_pos
            + self.epsilon
        )  # [B, C]
        t_loss = (1 - tversky_index).mean()  # average over batch and classes

        # --- Focal term ---------------------------------------------------
        # Each channel is treated as a binary problem: pt is the probability
        # assigned to the "correct" outcome for that channel at that pixel.
        pt = torch.where(annotations > 0, probabilities, 1 - probabilities)
        focal_loss = -torch.pow(1 - pt, self.gamma) * torch.log(pt)

        if self.weight is not None:
            # Broadcast per-class weights over batch and spatial dims.
            focal_loss = focal_loss * self.weight.view(1, -1, 1, 1)

        # Mean over spatial dims first ([B, C]), then over batch and classes.
        f_loss = focal_loss.mean(dim=(2, 3)).mean()

        return f_loss * self.focal_weight + t_loss * self.tversky_weight
```