I have a simple model for which I'm trying to obtain evaluation metrics (a segmentation task, originally trained by minimizing binary cross-entropy loss). However, given the imbalanced nature of my data, I decided to implement a focal loss function. Since it is a loss, do I also have to backpropagate through it and its class weights? I call `loss.backward()` on my BCE loss, but I'm looking for some intuition about why one would have two loss functions in a problem. Should I focus on one and remove the other, or is there a way to properly incorporate both? My implementation is below: first the weighted focal loss function, then my training function.
def weighted_focal_loss(outputs, targets, alpha, gamma):
    """Return the mean alpha-weighted binary focal loss computed from logits.

    Per element: ``a_t * (1 - p_t) ** gamma * BCE``, where ``BCE`` is the
    binary cross entropy on logits and ``p_t = exp(-BCE)`` is the predicted
    probability of the true class.

    Args:
        outputs: Raw logits (no sigmoid applied); same shape as ``targets``.
        targets: Binary target mask with values in {0, 1}. It is cast to
            ``long`` to index the class weights, so non-integer values are
            truncated by that cast.
        alpha: Weight applied to elements whose target is 0. Elements whose
            target is 1 are weighted by ``1 - alpha``.
        gamma: Focusing exponent. ``gamma = 0`` reduces to alpha-weighted BCE.

    Returns:
        A scalar tensor: the loss averaged over every element. It is
        differentiable w.r.t. ``outputs``.
    """
    # Per-element BCE, flattened so it lines up element-wise with the
    # gathered class weights below.
    bce = F.binary_cross_entropy_with_logits(
        outputs, targets, reduction="none"
    ).reshape(-1)
    # Create the weight table on the same device/dtype as the loss. A CPU
    # tensor here makes ``gather`` fail once the model and data are on GPU.
    class_weights = torch.tensor(
        [alpha, 1 - alpha], device=bce.device, dtype=bce.dtype
    )
    # Index 0 -> alpha (target 0), index 1 -> 1 - alpha (target 1).
    # ``reshape`` (unlike ``view``) also handles non-contiguous targets.
    at = class_weights.gather(0, targets.reshape(-1).long())
    pt = torch.exp(-bce)
    return (at * (1 - pt) ** gamma * bce).mean()
def train_val(model, loss, opt, scheduler, epochs):
train_ious = []
train_acc = []
train_loss = []
train_fl = []
for epoch in range(epochs):
model.train()
running_loss = 0
running_acc = 0
running_iou = 0
running_fl =0
for i, batch in tqdm(enumerate(train_dl), desc='training'):
x = batch['img'].float()
y = batch['fpt'].float().unsqueeze(dim=1)
output = model(x)
running_acc += pixel_accuracy(output, y)
running_iou += mIoU(output,y)
running_fl += weighted_focal_loss(output,y, 0.7, 3)
loss_epoch = loss(output, y)
running_loss += loss_epoch.item()
opt.zero_grad()
loss_epoch.backward(retain_graph=False)
opt.step()