# Can I use cross entropy loss as a binary loss?

Can I use cross entropy loss (`CrossEntropyLoss`) instead of `BCELoss` when my labels are binary (0, 1)?

I would appreciate some explanation and intuition on this.

Thanks


Yes, this would be possible. You would need to change some shapes and the `dtype` of your `target`.
Here is a small example with some comments:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# CrossEntropyLoss
class MyCEModel(nn.Module):
    def __init__(self):
        super(MyCEModel, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 2)  # two output neurons (one logit per class)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)  # raw logits, no softmax
        return x

model = MyCEModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
data = torch.randn(5, 10)
# target has no channel dimension and dtype=torch.long
target = torch.randint(0, 2, (5,), dtype=torch.long)

output = model(data)  # shape: (5, 2)
loss = criterion(output, target)
loss.backward()
optimizer.step()

# BCEWithLogitsLoss (binary cross entropy applied to raw logits)
class MyBCEModel(nn.Module):
    def __init__(self):
        super(MyBCEModel, self).__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 1)  # single output neuron

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)  # raw logit, no sigmoid
        return x

model = MyBCEModel()
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
# target has same shape and dtype as output
target = torch.randint(0, 2, (5, 1), dtype=torch.float32)

output = model(data)  # shape: (5, 1)
loss = criterion(output, target)
loss.backward()
optimizer.step()
```

As you can see, the main difference is that `MyCEModel` will have twice the number of parameters in the last linear layer compared to `MyBCEModel`.
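For some intuition on why the two setups are interchangeable, here is a small sketch (my own illustration, not part of the example above; the variable names are made up). With two classes, the softmax probability of class 1 equals the sigmoid of the logit difference `z1 - z0`, so both losses should give the same value when fed the same information. The last lines also count the parameters of the two output layers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 2)           # two logits per sample, as a two-neuron head would produce
target = torch.randint(0, 2, (5,))   # class indices, dtype=torch.long

# Cross entropy over the two logits (softmax is applied internally)
ce = F.cross_entropy(logits, target)

# Binary cross entropy on a single logit equal to the difference z1 - z0
single_logit = logits[:, 1] - logits[:, 0]
bce = F.binary_cross_entropy_with_logits(single_logit, target.float())

print(ce.item(), bce.item())  # the two values should agree up to floating point error

# Parameter count of the last layer: weights plus biases
n_ce = sum(p.numel() for p in nn.Linear(20, 2).parameters())
n_bce = sum(p.numel() for p in nn.Linear(20, 1).parameters())
print(n_ce, n_bce)
```

In other words, the second output neuron in `MyCEModel` adds parameters but no extra expressive power for a two-class problem, since only the difference between the two logits affects the loss.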


Just the target, or also the output?

That goes back to the point that the output needs to have shape NxC (C being the number of classes) for `CrossEntropyLoss`, while for BCE it should be just N. Am I understanding that correctly?

Also, `BCEWithLogitsLoss` can deal with imbalanced data. Is there a version of `CrossEntropyLoss` that has that feature as well?
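On the imbalance point, a minimal sketch of the two options as I understand them: `nn.BCEWithLogitsLoss` takes a `pos_weight` argument and `nn.CrossEntropyLoss` takes a per-class `weight` argument. The 9:1 ratio and the tensor names below are assumptions made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical imbalance: about 9 negatives for every positive sample
class_weights = torch.tensor([1.0, 9.0])  # one weight per class index
pos_weight = torch.tensor([9.0])          # weight applied to the positive class only

ce_criterion = nn.CrossEntropyLoss(weight=class_weights)
bce_criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits_2 = torch.randn(8, 2)        # stand-in for a two-neuron output
logits_1 = torch.randn(8, 1)        # stand-in for a single-neuron output
target = torch.randint(0, 2, (8,))  # class indices, dtype=torch.long

ce_loss = ce_criterion(logits_2, target)
bce_loss = bce_criterion(logits_1, target.float().unsqueeze(1))
print(ce_loss.item(), bce_loss.item())
```

The two weighted losses will not generally match numerically, even for equivalent logits: with the default `reduction='mean'`, `CrossEntropyLoss` divides by the sum of the weights of the targets in the batch, while `BCEWithLogitsLoss` averages over the number of elements.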

Also, can you please help me understand the way you define the targets? I'm confused about a couple of things there.
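For reference, here is how I read the two target definitions from the example, written as a short sketch (the names `target_ce` and `target_bce` are mine). `torch.randint(0, 2, size)` draws integers from 0 up to but not including 2, so every entry is either 0 or 1.

```python
import torch

torch.manual_seed(0)

# Targets for nn.CrossEntropyLoss: class indices, shape (N,), dtype torch.long
target_ce = torch.randint(0, 2, (5,), dtype=torch.long)

# The same labels for nn.BCEWithLogitsLoss with an (N, 1) model output:
# cast to float and add a trailing dimension so the shape matches the output
target_bce = target_ce.float().unsqueeze(1)

print(target_ce.shape, target_ce.dtype)
print(target_bce.shape, target_bce.dtype)
```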