nn.TransformerEncoder for classification

Hello all,

I’m trying to get the built-in PyTorch nn.TransformerEncoder to do a classification task; my eventual goal is to replicate the ToBERT model from this paper (paperswithcode is empty). Unfortunately, my model doesn’t seem to learn anything.

import math

import torch.nn as nn

class Net(nn.Module):
    def __init__(
        self,
        embeddings,
        nhead=8,
        nhid=200,
        num_layers=2,
        dropout=0.1,
        classifier_dropout=0.1,
        max_len=256,
    ):

        super().__init__()

        self.d_model = embeddings.size(1)
        assert (
            self.d_model % nhead == 0
        ), "nheads must divide evenly into d_model"

        self.emb = nn.Embedding.from_pretrained(embeddings, freeze=False)
        self.pos_encoder = PositionalEncoding(
            self.d_model, dropout=dropout, vocab_size=embeddings.size(0)
        )

        encoder_layers = nn.TransformerEncoderLayer(
            d_model=self.d_model, nhead=nhead, dim_feedforward=nhid, dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layers, num_layers=num_layers
        )
        self.classifier = nn.Sequential(
            # Other layers to go here if needed once things seem to be working
            nn.Linear(self.d_model, 2),
        )

    def forward(self, x):
        # x: (batch, seq_len) of token ids
        x = self.emb(x) * math.sqrt(self.d_model)

        x = self.pos_encoder(x)
        x = self.transformer_encoder(x)  # self.src_mask)
        x = x.mean(dim=1)  # mean-pool over the sequence dimension
        return self.classifier(x)
import math

import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, vocab_size=5000, dropout=0.1, batch_size=100):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # vocab_size doubles as the maximum sequence length here; the buffer
        # ends up shaped (1, max_len, d_model) to match batch_first inputs.
        pe = torch.zeros(vocab_size, d_model)
        position = torch.arange(0, vocab_size, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

The PositionalEncoding layer is taken almost directly from the PyTorch language modeling example, with the exception of changing dimensions to match my preference for batch_first=True.
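For reference, here’s the kind of quick shape check I’ve been running; the sizes and random embeddings below are made up purely for illustration:

# Made-up sizes and random embeddings, just to illustrate the batch_first
# layout I'm assuming: (batch, seq_len) token ids in, (batch, 2) logits out.
import torch

vocab_size, d_model = 5000, 128
model = Net(torch.randn(vocab_size, d_model))

batch = torch.randint(0, vocab_size, (32, 256))  # (batch, seq_len)
logits = model(batch)
print(logits.shape)  # torch.Size([32, 2])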

There are a few similar posts, all without definitive answers.

I found a couple of examples of transformers for classification:

Both of these seem to work with good accuracy, so I’m sure it’s possible, but both also seem to build the transformer “from scratch.” I’d like to figure out why I can’t get it to work with the pytorch TransformerEncoder.

When I run it, the loss doesn’t go anywhere even on the training set, so it’s clearly just not learning. I’ve gone through the PositionalEncoding layer a few times, since that’s where much of the complexity lies, and even tried replacing it with the positional encoding strategies used in the libraries above – no difference.

Does anyone see something I’m doing obviously wrong? Am I mistaken that I should be able to use a TransformerEncoder for classification in this way?

Many thanks in advance!

I found another example of someone trying to use nn.TransformerEncoder for sequence classification – unfortunately their model doesn’t seem to be learning anything either; accuracy on IMDB is 53% on the training set. I contacted the author, who thinks it may be attributable to suboptimal hyperparameters, but I’m not so sure. I’ve tried hyperparameters similar to what the former example above uses, with no improvement.

Speaking of which, I was reading through @pbloem’s excellent article on transformers and realized that he is the author of former, which I’ve been using as an example for a few weeks, and that he seems to be registered on this forum. If you happen to see this post, can spare a minute, and can immediately see an obvious reason why nn.TransformerEncoder is failing at classification where your former library does great, I’d be truly grateful if you could point me in the right direction.

EDIT: Also tried asking at r/learnmachinelearning, figured I might as well leave a link in case that gets a response that could help someone.

Hi n8henrie,

Thanks for the kind words. I had a look at your code, and I don’t see any obvious problems with it. The only difference from my version is that I use position embeddings rather than encodings, but I see you’ve already tried that.

If I were in your situation, I’d probably try the following:

  • Check that the embedding parameters are getting gradients. You can do this by calling w.retain_grad() on the weight tensor, and then printing the values just after the backward pass (see the sketch after this list).
  • Check the loss. Your model output is linear, so the loss should apply a softmax (i.e. use cross entropy loss).
  • Try it with 0 transformer layers (i.e. just train word embeddings). IMDb is simple enough that this should put you well over chance. You can compare to former with 0 layers to see what performance you can expect.
  • Disable the position encoding. The model should still be able to get some performance without any position information. If performance goes up with the positions disabled, you know the problem is somewhere in the position encodings.
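For the first check, I mean something like this minimal sketch; the random embeddings and batch are placeholders, and Net is your model class from above:

# Minimal sketch of the gradient check; embeddings and the batch are random
# placeholders, and Net is the model class defined earlier in the thread.
import torch
import torch.nn as nn

model = Net(torch.randn(5000, 128))        # random "pretrained" embeddings
inputs = torch.randint(0, 5000, (8, 64))   # (batch, seq_len) of token ids
labels = torch.randint(0, 2, (8,))         # binary targets

model.emb.weight.retain_grad()             # a no-op for leaf parameters, but harmless

criterion = nn.CrossEntropyLoss()          # applies the softmax internally
loss = criterion(model(inputs), labels)
loss.backward()

# If the embeddings are getting a learning signal, this should be nonzero.
print(model.emb.weight.grad.abs().mean())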

You may have tried all of this already, but this is where I’d start.

good luck,
Peter


Thanks so much for your kind words and suggestions.

I’ve made a lot of progress and my model is now learning. I’m not 100% clear what made the difference, but your recommendations were very helpful in sorting things out.

Check that the embedding parameters are getting gradients

They were indeed getting gradients.

Check the loss. Your model output is linear, so the loss should apply a softmax (i.e. use cross entropy loss).

I’m using nn.CrossEntropyLoss, which includes a softmax.
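For anyone reading along later, here’s a quick check with made-up logits that nn.CrossEntropyLoss really does include the softmax, so the raw outputs of the final linear layer are exactly what it expects:

# Made-up logits: CrossEntropyLoss on raw outputs matches log_softmax + NLLLoss.
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 2)            # raw outputs from the final nn.Linear
targets = torch.tensor([0, 1, 1, 0])

ce = nn.CrossEntropyLoss()(logits, targets)
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(ce, nll))        # True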

Disable the position encoding

What I did instead was to swap my position encoding implementation into former, and it didn’t hurt its learning.

Try it with 0 transformer layers (i.e. just train word embeddings).

This was a great recommendation and was really informative. Both former and my model learned well with 0 transformer layers. I then found that my model would learn okay with 1 or 2 layers but struggled at 3. Eventually I found that decreasing the learning rate from 0.001 to 0.0001 and extending training from 10 to 50 epochs showed that it was in fact learning. With some tinkering with the gradient clipping, the LambdaLR scheduler, and the initial learning rate, its accuracy on the training set now eventually reaches the high 90s. I know I had taken the learning rate down to 1e-5 previously to ensure that wasn’t the issue; my guess is that there were multiple problems, and each one masked the improvement when I tinkered with the other.
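In case it helps anyone later, the training setup ended up looking roughly like the sketch below; the optimizer choice, warmup schedule, clipping value, and random data are illustrative placeholders rather than my exact final settings:

# Rough sketch of the training setup described above; Adam, the warmup
# schedule, the clipping value, and the random data are placeholders.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

model = Net(torch.randn(5000, 128))  # random embeddings, just for the sketch
train_loader = DataLoader(
    TensorDataset(torch.randint(0, 5000, (256, 64)), torch.randint(0, 2, (256,))),
    batch_size=32,
    shuffle=True,
)

optimizer = optim.Adam(model.parameters(), lr=1e-4)  # 1e-3 was too high for me
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: min(1.0, (step + 1) / 1000)  # linear warmup
)
criterion = nn.CrossEntropyLoss()

for epoch in range(50):
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # gradient clipping
        optimizer.step()
        scheduler.step()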

Now that I can see that my model is capable of learning I consider this issue solved. Thank you so much!

I will update soon to post my actual model in case it’s helpful for future readers.


Hi, I’d like to add that I’ve also tried the built-in PyTorch TransformerEncoder class for a classification task; it seems to learn nothing, whereas this one by lucidrains works fine. Is there a bug in the native implementation? Or perhaps some missing features that the others may have?

Hi @n8henrie, thanks for having tackled this issue already. I was wondering whether your code above only considers the case in which padding tokens are not needed? If so, in order to extend it to work with padded inputs, I would need to change the mean operation so that it ignores the encoder embeddings of the padded tokens, right?

By “soon” I apparently meant a year. Sorry! I wrote up my code and results here, hope it’s helpful:

Good question, I’m not really sure.

It looks like the torchtext data does contain padding (see below; scroll to the end of the output), and it’s possible that masking (perhaps with src_key_padding_mask?) would improve performance. Perhaps my example code is just learning to essentially ignore the <pad> tokens?

>>> " ".join(TEXT.vocab.itos[tok] for tok in next(iter(train_iter)).text[0])
"without a doubt this is one of the worst films i've ever wasted money on! the plot is, erm sorry, did i say there was a plot? the scariest moment was <unk> nope can't think of one! the best special effect that had me hiding under the bed covers <unk> nope can't think of one for that either. you knew who the killer was right from the start. there was nothing scary about the whole movie, in fact the only two vaguely interesting bits were when you saw the kid sister, <unk> in the shower and when you saw nurse <unk> take her top off. this film should only be watched to get an idea of how not to make a horror movie!!! <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> 
<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>"
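I haven’t tested it, but ignoring the pads might look something like this sketch of an alternative forward pass. It assumes the TransformerEncoderLayer is built with batch_first=True so the activations stay (batch, seq_len, d_model), and pad_idx is a placeholder for wherever <pad> lands in the vocab:

# Untested sketch: exclude <pad> positions from both the attention and the
# mean-pooling step. Assumes batch_first=True on the encoder layer, so x is
# (batch, seq_len, d_model); pad_idx is whatever index <pad> has in the vocab.
import math
import torch

def forward_with_padding_mask(model, x, pad_idx):
    padding_mask = x == pad_idx                   # (batch, seq_len), True at pads
    x = model.emb(x) * math.sqrt(model.d_model)
    x = model.pos_encoder(x)
    x = model.transformer_encoder(x, src_key_padding_mask=padding_mask)
    keep = (~padding_mask).unsqueeze(-1).float()  # (batch, seq_len, 1)
    x = (x * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1)  # masked mean
    return model.classifier(x)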

@n8henrie thanks for sharing your code. I was literally running into the same issue… but glad I got a similar accuracy to yours on IMDB.