I am using the following code snippet for focal loss for binary classification on the output of vision transformer. Vision Transformer in my case throws two values as output. So, I used a sigmod of difference of the two outputs as follows below. Could you please confirm if it is correct?
class FocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, inputs, targets):
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
pt = torch.exp(-BCE_loss) # prevents nans when probability 0
F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
return F_loss.mean()
criterion = FocalLoss()
m = nn.Sigmoid()
and then inside the train:
if train:
print('training...')
torch.autograd.set_detect_anomaly(True)
for i_batch, sample_batched in enumerate(dataloader_train):
#pdb.set_trace()
feats = torch.stack(sample_batched['image'])
labels = torch.as_tensor(sample_batched['label']).cuda()
print('feats shape: ', feats.shape)
print('labels shape: ', labels.shape)
output = model(feats)
loss = criterion(m(output[:,1]-output[:,0]), labels.float())
#loss = criterion(output, labels)
print('train loss is: ', loss)
optimizer.zero_grad()
loss.backward()
optimizer.step()
acc = (output.argmax(dim=1) == labels).float().mean()
train_preds = output.argmax(dim=1)
I have this:
train_epoch_accuracy: 0.84375
not test
Evaluating...
epoch is: 49
evaluating...
epoch val acc: tensor(0.8541, device='cuda:0')
val_epoch_accuracy: 0.8426966292134831
best val acc: tensor(0.8541, device='cuda:0')
best epoch: 0
best preds: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
best val labels: [0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
As you see all predicted values are 0 which is the majority class (class 0 is 84% of the data).