Tversky Focal loss

I want to implement a custom loss function for a U-Net model trained on H&E images. This is what I have so far, but I am not sure whether I made any reasoning mistakes. Both my predictions and my annotations have shape [B, C, H, W], and the annotations are one-hot encoded: each pixel has a 1 in the channel of its class and a 0 in every other channel.

class Tversky_Focal_Loss(nn.Module):
    """Weighted sum of a Tversky loss and a focal loss for segmentation.

    Both terms are computed from channel-wise softmax probabilities of the
    raw network outputs and one-hot encoded annotations of the same shape
    ``[B, C, *spatial]`` (e.g. ``[B, C, H, W]``).

    Args:
        device: Device the inputs (and the class weights) are moved to.
        weight: Optional per-class weights with ``C`` elements, applied to
            the focal term only. ``None`` disables class weighting.
        alpha: Factor applied to the false negatives in the Tversky index.
        beta: Factor applied to the false positives in the Tversky index.
        gamma: Focusing exponent of the focal term.
        epsilon: Clamp margin for the probabilities and smoothing constant
            of the Tversky index.
        focal_weight: Factor of the focal term in the final sum.
        tversky_weight: Factor of the Tversky term in the final sum.
    """

    def __init__(self, device, weight=None, alpha=0.85, beta=0.15, gamma=3.0,
                 epsilon=1e-7, focal_weight=0.3, tversky_weight=0.7):
        super().__init__()
        # The attribute must exist even without class weights, because
        # forward() tests ``self.weight is not None``. Casting to float32
        # keeps the product with the float32 focal map free of dtype
        # promotion surprises.
        if weight is not None:
            self.weight = torch.as_tensor(
                weight, dtype=torch.float32, device=device
            )
        else:
            self.weight = None
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.epsilon = epsilon
        self.focal_weight = focal_weight
        self.tversky_weight = tversky_weight
        self.device = device

    def forward(self, activations, annotations):
        """Return the combined loss as a scalar tensor.

        Args:
            activations: Raw (pre-softmax) outputs, shape ``[B, C, *spatial]``.
            annotations: One-hot targets with the same shape as
                ``activations``.

        Raises:
            ValueError: If the two shapes differ or there is no spatial
                dimension.
        """
        activations = activations.float().to(self.device)
        annotations = annotations.float().to(self.device)

        # Fail loudly instead of letting mismatched inputs broadcast.
        if activations.shape != annotations.shape:
            raise ValueError(
                "activations and annotations must have the same shape, got "
                f"{tuple(activations.shape)} and {tuple(annotations.shape)}"
            )
        if activations.dim() < 3:
            raise ValueError(
                "expected inputs of shape [B, C, *spatial], got "
                f"{tuple(activations.shape)}"
            )
        spatial_dims = tuple(range(2, activations.dim()))

        # Softmax across channels gives class probabilities; the clamp keeps
        # log() below away from 0.
        probabilities = torch.softmax(activations, dim=1)
        probabilities = torch.clamp(
            probabilities, min=self.epsilon, max=1 - self.epsilon
        )

        # Soft confusion counts per sample and class, shape [B, C].
        true_pos = (probabilities * annotations).sum(dim=spatial_dims)
        false_neg = ((1 - probabilities) * annotations).sum(dim=spatial_dims)
        false_pos = (probabilities * (1 - annotations)).sum(dim=spatial_dims)

        tversky_index = (true_pos + self.epsilon) / (
            true_pos
            + self.alpha * false_neg
            + self.beta * false_pos
            + self.epsilon
        )
        # Averaged over batch and classes.
        t_loss = (1 - tversky_index).mean()

        # Focal term, evaluated per channel: pt is the probability assigned
        # to the annotated state (p where the annotation is set, else 1 - p).
        pt = torch.where(annotations > 0, probabilities, 1 - probabilities)
        focal_loss = -torch.pow(1 - pt, self.gamma) * torch.log(pt)

        if self.weight is not None:
            # Broadcast the class weights as [1, C, 1, ...]. The product is
            # taken out of place so no autograd-tracked tensor is mutated.
            weight_shape = (1, -1) + (1,) * len(spatial_dims)
            focal_loss = focal_loss * self.weight.view(weight_shape)

        # Every element has equal weight, so a single mean over all axes
        # equals the mean over spatial dims followed by batch and classes.
        f_loss = focal_loss.mean()

        return self.focal_weight * f_loss + self.tversky_weight * t_loss