Value of [CLS] Token for Transformer Encoders

Hi everyone,

I want to use a transformer encoder for sequence classification. Following the idea of BERT, I want to prepend a [CLS] token to the input sequence. After passing the sequence forward through the transformer encoder, I plan to use the encoding of the [CLS] token as the representation of the whole sequence.

As far as I understand, it doesn’t matter what value I initialize this [CLS] token with, since the transformer will (hopefully) learn not to attend to it anyway: the token itself doesn’t contain any information about the sequence.

Is there any reason why I shouldn’t initialize this token randomly during each forward pass?
This would make the value of the [CLS] token different on every forward pass, but as far as I can tell that shouldn’t change anything.

Here is a minimal example of what I am planning to do:

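```python
import torch
import torch.nn as nn


class ClassificationTransformer(nn.Module):
    def __init__(self, embedding_dim, n_head, depth):
        super(ClassificationTransformer, self).__init__()
        # embedding_dim must be divisible by n_head
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=n_head,
            batch_first=True,  # inputs are [B, T, F]
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer,
            num_layers=depth,
        )

    def forward(self, x):
        # assume x of shape [B, T, F]
        # draw a fresh random [CLS] token for every sequence in the batch
        cls_token = torch.randn((x.shape[0], 1, x.shape[-1]), device=x.device, dtype=x.dtype)
        # prepend along the sequence dimension -> [B, 1+T, F]
        tokens = torch.cat((cls_token, x), dim=1)
        # encode the tokens (not x), otherwise the [CLS] token is never seen
        encodings = self.transformer_encoder(tokens)
        # output at position 0 is the sequence representation, [B, F]
        representation = encodings[:, 0, :]
        return representation
```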
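One way to see the consequence of drawing the [CLS] token inside forward: with dropout switched off (eval mode), the same input gives a different representation on every call. This is a minimal sketch, assuming a small untrained encoder with arbitrary sizes; the helper name represent is hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

B, T, F = 2, 5, 16
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=F, nhead=4, batch_first=True),
    num_layers=2,
)
encoder.eval()  # disable dropout so any difference comes from the [CLS] draw


def represent(x):
    # hypothetical helper, same scheme as the code above:
    # a fresh random [CLS] token on every call
    cls_token = torch.randn((x.shape[0], 1, x.shape[-1]))
    tokens = torch.cat((cls_token, x), dim=1)  # [B, 1+T, F]
    return encoder(tokens)[:, 0, :]            # [B, F]


x = torch.randn(B, T, F)
with torch.no_grad():
    r1 = represent(x)
    r2 = represent(x)

print(r1.shape)                # torch.Size([2, 16])
print(torch.allclose(r1, r2))  # False: same input, different representation
```

So the representation is not a deterministic function of the input sequence, which the classifier on top then has to cope with.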

Thanks for helping! 🙂

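The transformer computes a query, a key and a value from every input token, including the [CLS] token. You may have a point that the [CLS] token’s value is not needed, but you do want at least its query to work: the output at the [CLS] position is an attention-weighted mix of the sequence, and those weights come from the [CLS] query. With a random [CLS] input, that query, and therefore what the token attends to, changes on every forward pass.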
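To illustrate the query argument, here is a stripped-down single-head attention step, ignoring biases, multiple heads and layer norm, with random matrices standing in for trained weights (all names are hypothetical). The sequence x stays fixed; only the [CLS] vector is redrawn, yet the attention pattern over the sequence changes.

```python
import torch

torch.manual_seed(0)

F, T = 8, 4
# hypothetical projection matrices of a single attention head
W_q, W_k, W_v = (torch.randn(F, F) for _ in range(3))

x = torch.randn(T, F)  # the actual sequence, identical between calls


def cls_attention(cls_token):
    tokens = torch.cat((cls_token[None, :], x), dim=0)   # [1+T, F]
    q = cls_token @ W_q                                  # query of the [CLS] position, [F]
    k = tokens @ W_k                                     # keys of all positions, [1+T, F]
    v = tokens @ W_v                                     # values of all positions, [1+T, F]
    weights = torch.softmax(q @ k.T / F ** 0.5, dim=-1)  # attention pattern, [1+T]
    return weights, weights @ v                          # pattern and [CLS] output, [F]


w1, out1 = cls_attention(torch.randn(F))
w2, out2 = cls_attention(torch.randn(F))
print(w1)
print(w2)  # a different attention pattern over the same sequence
```

Even if training taught the model to ignore the [CLS] value, the weights above are still driven by the [CLS] query, so a random token gives a random pooling of the sequence.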

Also, I’d probably ask the question the other way round and not use randomness unless I have a reason to do so.

Best regards

Thomas

Hi Thomas,

thanks a lot for your answer.
What would you propose to use as the [CLS] token instead? I thought that a tensor of all zeros might lead to poor numerical behaviour.
Would it be a better option to set the [CLS] token randomly, but only once (e.g. in the __init__ method), so that it stays fixed during training?
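For reference, the fixed-random variant asked about here could look like the following minimal sketch. The module name FixedClsToken is hypothetical; registering the token as a buffer is one way (an assumption, not the only one) to keep it out of the optimizer while still saving it with the model.

```python
import torch
import torch.nn as nn


class FixedClsToken(nn.Module):
    """Hypothetical helper: prepends a [CLS] token drawn once at construction time."""

    def __init__(self, embedding_dim):
        super().__init__()
        # a buffer is saved in state_dict and moved by .to(device),
        # but it is not returned by parameters(), so it is never trained
        self.register_buffer("cls_token", torch.randn(1, 1, embedding_dim))

    def forward(self, x):
        # x: [B, T, F] -> [B, 1+T, F]
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        return torch.cat((cls, x), dim=1)


prepend = FixedClsToken(embedding_dim=16)
x = torch.randn(2, 5, 16)
a, b = prepend(x), prepend(x)
print(a.shape)                          # torch.Size([2, 6, 16])
print(torch.equal(a, b))                # True: same token on every call
print(len(list(prepend.parameters())))  # 0: nothing to learn
```

This removes the per-pass randomness, but the token itself can never adapt to the task.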

The typical thing (GPT, vision transformer, …) is to make it a learned parameter, e.g. here:
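As an illustration (not the linked code), a minimal sketch of the learned-parameter approach applied to the classifier from the question. The 0.02 initialization scale is an assumption, a common small-random choice rather than a requirement.

```python
import torch
import torch.nn as nn


class ClassificationTransformer(nn.Module):
    def __init__(self, embedding_dim, n_head, depth):
        super().__init__()
        # learned [CLS] embedding shared by all sequences;
        # the 0.02 init scale is an assumption, not a requirement
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim) * 0.02)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim, nhead=n_head, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=depth)

    def forward(self, x):
        # x: [B, T, F]; broadcast the single learned token over the batch
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        tokens = torch.cat((cls, x), dim=1)  # [B, 1+T, F]
        encodings = self.transformer_encoder(tokens)
        return encodings[:, 0, :]            # [B, F]


model = ClassificationTransformer(embedding_dim=16, n_head=4, depth=2)
x = torch.randn(3, 7, 16)
rep = model(x)
print(rep.shape)                         # torch.Size([3, 16])
rep.sum().backward()
print(model.cls_token.grad is not None)  # True: the token receives gradients
```

Because the token is an nn.Parameter, it is returned by model.parameters() and updated by the optimizer together with the encoder weights, so its query can learn what to attend to.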

Best regards

Thomas


Thank you so much, this is exactly what I was looking for 🙂