Probabilities in multi-label classification are all ~0.5

Hi There,

I have a joint model which takes images and text as input and produces two embeddings: one for the image and one for the text. Additionally, another branch of the network takes the image embedding and, via a single fully connected layer, produces one logit per class, from which I get the probability of each class being present in the image.

The model is defined as:

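import torch
import torch.nn as nn
from collections import OrderedDict

class ImageModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model # resnet model
        self.logit_size = None

    def change_logit_size(self, logit_size: int) -> None:
        self.logit_size = logit_size
        # Replace the ResNet head. The right-hand side is built before the
        # assignment, so self.model.fc still refers to the original layer here.
        self.model.fc = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.model.fc.in_features, self.model.fc.out_features),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(self.model.fc.out_features, logit_size)
        )

    def forward(self, images):
        return self.model(images)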

class WordEmbedding(nn.Module):

    def __init__(self, n_words, embedding_dim):
        super(WordEmbedding, self).__init__()
        self.embedding_dim = embedding_dim
        self.word_embeddings = torch.nn.Embedding(n_words, embedding_dim, padding_idx=0)

    def forward(self, product_titles):
        summed_embeddings = []
        for title in product_titles:
            title_embeddings = self.word_embeddings(title)
            title_embeddings = title_embeddings.sum(dim=0)
            summed_embeddings.append(title_embeddings)
        return torch.stack(summed_embeddings)
    
class Ensemble(nn.Module):
    def __init__(self, embedding_dim, image_model, text_model, n_meta_categories):
        super(Ensemble, self).__init__()
        self.embedding_dim = embedding_dim
        self.image_model = image_model
        self.text_model = text_model
        
        # Takes the image embedding as an input and predicts the meta classes in the image
        self.attribute_prediction = nn.Sequential(OrderedDict([
            ('attribute-fc', nn.Linear(self.embedding_dim, n_meta_categories)),
        ]))

    def forward(self, x):
        images, title_tokens, meta_tokens = x
        title_tokens = title_tokens.long()
        embedding_text = self.text_model(title_tokens)
        embedding_image = self.image_model(images)
        attribute_prediction = self.attribute_prediction(embedding_image)
        return embedding_image, embedding_text, attribute_prediction

The training process looks like this:

# A train class is defined earlier and provides additional functionality

bce_loss = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001)

dataloader = DataLoader(dataset_train, batch_size=256, shuffle=True, num_workers=2)

for _ in range(50):

    for batch in dataloader:

        optimizer.zero_grad()

        # meta_cats_batch = multi-label categories
        image_batch, tokens_batch, meta_cats_batch = batch
        inputs = (image_batch.to(DEVICE), tokens_batch.to(DEVICE), meta_cats_batch.to(DEVICE))

        with torch.set_grad_enabled(True):
            image_embeddings, text_embeddings, meta_preds = model(inputs)

        cosine_matrix = train.cosine_distance(image_embeddings, text_embeddings)

        mbmr_loss = train.calculate_loss(cosine_matrix, temperature=0.025)
        mbmr_loss /= image_batch.size(0)
        meta_cats_loss = Variable(bce_loss(meta_preds.detach().cpu(), meta_cats_batch), requires_grad=True)

        loss = mbmr_loss + meta_cats_loss
        loss.backward()
        optimizer.step()

        print('Meta Category BCE Loss:', meta_cats_loss.item())

The mbmr loss works as expected and decreases as the model trains. The issue lies with meta_cats_loss: after 50 epochs of training (intentionally overfitting on a small dataset), the loss remains static, e.g.:

# Last 5 epochs and all previous epochs are around 0.69
Meta Category BCE Loss: 0.6912458344473028
Meta Category BCE Loss: 0.6919576254937662
Meta Category BCE Loss: 0.6923714890474508
Meta Category BCE Loss: 0.6920379958767525
Meta Category BCE Loss: 0.692357156795131
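For reference, a loss pinned near 0.693 is the signature of logits stuck at zero: sigmoid(0) = 0.5, and binary cross-entropy at p = 0.5 equals ln 2 ≈ 0.6931 whatever the target is. A minimal sketch in plain Python; the helper `bce_with_logits` is hypothetical, written out for illustration rather than taken from PyTorch:

```python
import math

def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy on a single logit (hypothetical helper)."""
    p = 1.0 / (1.0 + math.exp(-logit))  # sigmoid turns the logit into a probability
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))

# A head whose logits sit at 0 predicts p = 0.5 for every class,
# so the loss is ln 2 for both positive and negative targets.
loss_pos = bce_with_logits(0.0, 1.0)
loss_neg = bce_with_logits(0.0, 0.0)
print(round(loss_pos, 4), round(loss_neg, 4), round(math.log(2), 4))  # 0.6931 0.6931 0.6931
```

So the ~0.69 loss and the ~0.5 probabilities below are the same symptom: the attribute head is not learning at all.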

If I compute the probabilities by taking the sigmoid of the logits in meta_preds, I get:

tensor([[0.4960, 0.4946, 0.5037, 0.4950],
        [0.4985, 0.5009, 0.4980, 0.4999],
        [0.4929, 0.4974, 0.5000, 0.4961],
        [0.5034, 0.4920, 0.4988, 0.5023],
        [0.4997, 0.5015, 0.4996, 0.5013],
        [0.5010, 0.4978, 0.5017, 0.4963],
        [0.4944, 0.4953, 0.4961, 0.5004],
        [0.5005, 0.4998, 0.5054, 0.5016],
        [0.5059, 0.5084, 0.4994, 0.5007],
        [0.4963, 0.4975, 0.5027, 0.4993],
        [0.4954, 0.4934, 0.4999, 0.4949],
        [0.4993, 0.5061, 0.5056, 0.5055],
        [0.4992, 0.5008, 0.5003, 0.4957],
        [0.4986, 0.4959, 0.5012, 0.5054],
        [0.4982, 0.5096, 0.5024, 0.5014],
        [0.4959, 0.5086, 0.4982, 0.4996],
        [0.5014, 0.5020, 0.5012, 0.5015],
        [0.4961, 0.5028, 0.4989, 0.4993],
        ...
        [0.4987, 0.4968, 0.4968, 0.4980],
        [0.4993, 0.4999, 0.4976, 0.4995],
        [0.5010, 0.5011, 0.5015, 0.4983],
        [0.4954, 0.4964, 0.4975, 0.5037],
        [0.4891, 0.4942, 0.4993, 0.4979],
        [0.5000, 0.5086, 0.4978, 0.4972],
        [0.4979, 0.4957, 0.5000, 0.4980],
        [0.4976, 0.5014, 0.4995, 0.5001],
        [0.5033, 0.4998, 0.5012, 0.4981],
        [0.5024, 0.5009, 0.4998, 0.5004],
        [0.5054, 0.4989, 0.5054, 0.5027]], device='cuda:0',
       grad_fn=<SigmoidBackward>)

All the probabilities are around 0.5, and I cannot understand why.

Hi Harpal!

This detach() is wrong:

bce_loss(meta_preds.detach().cpu(), meta_cats_batch)

It “breaks the computation graph” in that it “detaches” meta_preds
from the computation of bce_loss() so that you don’t backpropagate
through, and optimize, the weights that were used to produce
meta_preds.

That is, meta_cats_loss has no effect on your training.

(Note that Variable was deprecated a long time ago, so you neither need
nor want it. Also, the requires_grad=True in your Variable doesn’t undo
the damage done by detach(): it only turns the loss into a new leaf
tensor that has no connection to your model’s weights.)
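To make this concrete, here is a minimal sketch using a small hypothetical `nn.Linear` head as a stand-in for `attribute_prediction` (the sizes, random data, and learning rate are assumptions for illustration). It shows that the detached loss leaves the head’s gradient at None, while computing the loss on the attached logits both produces gradients and lets the loss fall below ln 2:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the attribute head: 8 samples,
# 16-dim embeddings, 4 meta categories with multi-hot float targets.
head = nn.Linear(16, 4)
embeddings = torch.randn(8, 16)
targets = torch.randint(0, 2, (8, 4)).float()
bce_loss = nn.BCEWithLogitsLoss()

# Broken: detach() cuts the graph; re-enabling requires_grad afterwards
# (what Variable(..., requires_grad=True) did) just makes a new leaf.
logits = head(embeddings)
broken_loss = bce_loss(logits.detach(), targets).requires_grad_(True)
broken_loss.backward()
broken_grad = head.weight.grad  # still None: nothing reached the head

# Fixed: compute the loss on the attached logits.
head.zero_grad()
fixed_loss = bce_loss(head(embeddings), targets)
fixed_loss.backward()
fixed_grad = head.weight.grad  # now populated

# With gradients flowing, the head can overfit this tiny batch.
optimizer = torch.optim.Adam(head.parameters(), lr=0.1)
for _ in range(200):
    optimizer.zero_grad()
    loss = bce_loss(head(embeddings), targets)
    loss.backward()
    optimizer.step()
final_loss = loss.item()
print(broken_grad is None, fixed_grad is not None, final_loss)
```

In your loop, the equivalent change would be something like meta_cats_loss = bce_loss(meta_preds, meta_cats_batch.to(DEVICE).float()), assuming meta_cats_batch holds 0/1 multi-hot labels: no Variable, no detach(), and no .cpu(), so the targets simply move to the same device as the predictions.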

I haven’t looked at the rest of your code, but this issue is a good place
to start.

Best.

K. Frank

Hi @KFrank,

Thank you for your swift reply. You’re right, the detach() was causing the issue. It’s all working as expected now.

Thanks for helping me debug!