Converting single-label classification to multi-label classification

I am currently using an LSTM model to do binary classification on a text dataset and was wondering how to go about extending this model to perform multi-label classification. The current model is as follows:

import torch
import torch.nn as nn

class LSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, bidirectional,
                 dropout_rate, pad_index):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_index)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers, bidirectional=bidirectional,
                            dropout=dropout_rate, batch_first=True)
        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, ids, length):
        # ids: [batch, seq_len], length: [batch]
        embedded = self.dropout(self.embedding(ids))

        # pack so the LSTM skips padding positions
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, length, batch_first=True,
                                                            enforce_sorted=False)
        packed_output, (hidden, cell) = self.lstm(packed_embedded)

        output, output_length = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)

        if self.lstm.bidirectional:
            # concatenate the final forward and backward hidden states
            hidden = self.dropout(torch.cat([hidden[-1], hidden[-2]], dim=-1))
        else:
            hidden = self.dropout(hidden[-1])

        prediction = self.fc(hidden)

        return prediction

Currently, the output of the model has 2 dimensions, as each label can be either 1 or 0. But I would like to convert it to classify a one-hot encoding of the form [0, 0, 1, 0, 0, 1, 0]. I am aware I would probably need to change the criterion from nn.CrossEntropyLoss, but I am unsure what other steps to take.

Hi Kerrangcash!

When you say “multi-label” I assume that you mean “multi-label, multi-class”
classification. That is, you have multiple classes, presumably, but not
necessarily, more than two. And each sample can be labelled with any
number of those several classes, including none or all.

First of all, the example vector you give represents not “one-hot encoding,”
but “multi-hot encoding.” (A one-hot encoded vector would contain exactly
one 1, with all the rest 0s.)
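
For instance (a small illustration; the label positions here are made up):

import torch

# one-hot: exactly one class active per sample (single-label classification)
one_hot   = torch.tensor([0, 0, 1, 0, 0, 0, 0])

# multi-hot: any number of classes active per sample (multi-label classification)
multi_hot = torch.tensor([0, 0, 1, 0, 0, 1, 0])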

Your example vector has length 7. Therefore I assume that you have 7
classes (which should be the value of output_dim).

You are right that you will need to change the criterion. If I am correct
that you are performing a multi-label, multi-class classification, then you
should use BCEWithLogitsLoss. The point is that a multi-label problem is,
in some sense, n separate binary-classification problems run through your
network at the same time. (Is your sample in “class-1” – yes or no? Is it
separately in “class-2” – yes or no? And so on.)

So you would want the output of your model to be a final Linear layer
with out_features = output_dim (so, for this example, 7), as you have
in the code you posted. Feed these logits directly to BCEWithLogitsLoss
without any intervening “activation” layer.
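
Here is a minimal sketch of how the pieces fit together at training time.
(The batch size of 32 and the random logits and labels are stand-ins, not
your real data; in practice logits would be the output of model(ids, length)
with output_dim = 7.)

import torch
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

batch_size, num_classes = 32, 7

# stand-in for model(ids, length): raw scores from the final Linear layer
logits = torch.randn(batch_size, num_classes, requires_grad=True)

# stand-in multi-hot labels, cast to float (see below)
targets = torch.randint(0, 2, (batch_size, num_classes)).float()

# no sigmoid in between -- BCEWithLogitsLoss applies it internally
loss = criterion(logits, targets)
loss.backward()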

The target (ground-truth labels) you pass to BCEWithLogitsLoss can
be your multi-hot-encoded label vector, except that it has to be float,
rather than int. (This is because BCEWithLogitsLoss accepts
probabilistic targets that can range from 0.0 to 1.0, but having them
be exactly 0.0 or 1.0 is perfectly fine.)
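
And at inference time you would turn the logits back into per-class yes/no
decisions yourself, since the sigmoid lives inside the loss rather than in
the model. A small sketch (the 0.5 threshold is the conventional default,
and the random logits are a stand-in for your model’s output):

import torch

# training side: cast integer multi-hot labels to float before the loss
int_labels = torch.tensor([[0, 0, 1, 0, 0, 1, 0]])
float_labels = int_labels.float()

# inference side: logits -> independent per-class probabilities -> yes/no
logits = torch.randn(1, 7)              # stand-in for model(ids, length)
probs = torch.sigmoid(logits)           # each entry in [0.0, 1.0], per class
predictions = (probs > 0.5).long()      # 1 where the class is predicted present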

Best.

K. Frank