Can I use cross entropy loss as a binary loss?

Can I use cross entropy loss (CrossEntropyLoss) instead of BCELoss when my labels are binary (0 and 1)?

I would appreciate some explanation and intuition on this.

Thanks


Yes, this would be possible. You would need to change some shapes and the dtype of your target.
Here is a small example with some comments:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# CrossEntropyLoss
class MyCEModel(nn.Module):
    def __init__(self):
        super(MyCEModel, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 2) # two output neurons
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = MyCEModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
data = torch.randn(5, 10)
# target holds class indices: no class dimension (shape [N]) and dtype=torch.long
target = torch.randint(0, 2, (5,), dtype=torch.long)

output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()

# BCEWithLogitsLoss (applies the sigmoid internally, so the model returns raw logits)
class MyBCEModel(nn.Module):
    def __init__(self):
        super(MyBCEModel, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1) # single output neuron
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = MyBCEModel()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
# target has same shape and dtype as output
target = torch.randint(0, 2, (5, 1),  dtype=torch.float32)

output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()

As you can see, the main difference is that the last linear layer of MyCEModel has twice as many parameters as that of MyBCEModel, since it has two output neurons instead of one.
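For some intuition on why both setups work: with two logits z0 and z1, the softmax probability of class 1 is exp(z1) / (exp(z0) + exp(z1)) = sigmoid(z1 - z0). So CrossEntropyLoss on [z0, z1] should give the same value as BCEWithLogitsLoss on the single logit z1 - z0. The sketch below checks this numerically with random, made-up logits (not taken from the models above) and also counts the parameters of the two output layers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Made-up two-class logits [z0, z1] and binary targets (class indices)
logits = torch.randn(8, 2)
target = torch.randint(0, 2, (8,))

# CrossEntropyLoss: softmax over both logits, target as long class indices
ce = nn.CrossEntropyLoss()(logits, target)

# softmax([z0, z1])[1] == sigmoid(z1 - z0), so the logit difference plays
# the role of the single output a BCE model would produce
diff = logits[:, 1] - logits[:, 0]
bce = nn.BCEWithLogitsLoss()(diff, target.float())

print(ce.item(), bce.item())  # equal up to floating-point rounding

# Last-layer parameter counts: weights + biases
n_ce = sum(p.numel() for p in nn.Linear(20, 2).parameters())   # 20*2 + 2 = 42
n_bce = sum(p.numel() for p in nn.Linear(20, 1).parameters())  # 20*1 + 1 = 21
print(n_ce, n_bce)
```

Only the difference between the two logits matters, which is why the second output neuron is redundant for a binary problem: both setups can express the same predictions.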


Just the target, or also the output?

That goes back to the point that the output needs shape N x C (C being the number of classes) for CrossEntropyLoss (CEL), while for BCE it should be just N. Am I understanding this correctly?
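A quick shape check may help here. The sketch below uses random tensors, and N = 5, C = 2 are arbitrary choices. CrossEntropyLoss takes [N, C] logits with an [N] long target; BCEWithLogitsLoss only requires that output and target have the same shape, so both [N, 1] (as in MyBCEModel) and [N] work.

```python
import torch
import torch.nn as nn

N, C = 5, 2

# CrossEntropyLoss: logits of shape [N, C], target of shape [N] (class indices)
ce_out = torch.randn(N, C)
ce_target = torch.randint(0, C, (N,))
ce_loss = nn.CrossEntropyLoss()(ce_out, ce_target)

# BCEWithLogitsLoss: target must match the output shape
bce_out = torch.randn(N, 1)
bce_target = torch.randint(0, 2, (N, 1)).float()
loss_2d = nn.BCEWithLogitsLoss()(bce_out, bce_target)                        # [N, 1] vs [N, 1]
loss_1d = nn.BCEWithLogitsLoss()(bce_out.squeeze(1), bce_target.squeeze(1))  # [N] vs [N]

print(ce_out.shape, ce_target.shape)   # torch.Size([5, 2]) torch.Size([5])
print(loss_2d.item(), loss_1d.item())  # same value: squeezing does not change the mean
```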

Also, BCEWithLogitsLoss can handle imbalanced data; is there a version of CEL with that feature as well?
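CrossEntropyLoss has a `weight` argument that assigns one weight per class, which can be used for imbalanced data much like `pos_weight` in BCEWithLogitsLoss. As a sketch with made-up logits and a hypothetical weight of 3 for class 1, the two weighted losses coincide when both use reduction="sum" and the BCE logit is z1 - z0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

logits = torch.randn(8, 2)          # made-up two-class logits [z0, z1]
target = torch.randint(0, 2, (8,))  # class indices 0/1
w_pos = 3.0                         # hypothetical weight for the rarer class 1

# CrossEntropyLoss: one weight per class through the `weight` argument
ce = nn.CrossEntropyLoss(weight=torch.tensor([1.0, w_pos]), reduction="sum")(logits, target)

# BCEWithLogitsLoss: weight for positive samples through `pos_weight`
diff = logits[:, 1] - logits[:, 0]
bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([w_pos]), reduction="sum")(diff, target.float())

print(ce.item(), bce.item())  # equal up to floating-point rounding
```

With the default reduction="mean" the two values differ: the weighted CrossEntropyLoss divides by the sum of the per-sample class weights, while BCEWithLogitsLoss divides by N.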

Also, could you please help me understand how you define the targets? I'm confused about a couple of things regarding them.
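The two target formats differ because CrossEntropyLoss, as used above, expects class indices (which class is correct for each sample), whereas BCEWithLogitsLoss expects a float target per output element, interpreted as the probability of the positive class. A minimal sketch of converting between the two, assuming the [N, 1] output shape of MyBCEModel:

```python
import torch

N = 5

# CrossEntropyLoss target: class indices, shape [N], dtype long
ce_target = torch.randint(0, 2, (N,), dtype=torch.long)

# BCEWithLogitsLoss target: same labels as floats, shape [N, 1] to match MyBCEModel's output
bce_target = ce_target.float().unsqueeze(1)

# And back again: drop the extra dimension and cast to long
restored = bce_target.squeeze(1).long()

print(ce_target.dtype, ce_target.shape)    # torch.int64 torch.Size([5])
print(bce_target.dtype, bce_target.shape)  # torch.float32 torch.Size([5, 1])
```

Because the BCE target is a float, it can also hold soft labels between 0 and 1, not just hard 0/1 values.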