I want to use a transformer encoder for sequence classification. Following the idea of BERT, I want to prepend a [CLS] token to the input sequence. After passing my sequence through the transformer encoder, I plan to use the encoding of the [CLS] token as the representation of the whole sequence.
As far as I understand, it doesn't make a difference what value I initialize this [CLS] token with: the transformer will (hopefully) learn not to attend to the [CLS] token's value anyway, since it doesn't contain any information about the sequence.
Is there any reason why I shouldn't initialize this token randomly on each forward pass?
This would mean the [CLS] token has a different value on every forward pass, but as far as I can tell that shouldn't change anything.
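For comparison, my understanding is that the BERT-style alternative would be a single learned [CLS] embedding registered as a parameter. A minimal sketch of what I mean (class and attribute names are just placeholders):

import torch
import torch.nn as nn

class LearnedClsToken(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        # one trainable embedding, shared across all inputs and updated by backprop
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

    def forward(self, x):
        # x: [B, T, F]; broadcast the single token to the batch size and prepend it
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        return torch.cat((cls, x), dim=1)  # [B, 1+T, F]

My question is whether replacing this learned parameter with fresh random values actually loses anything.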
Here is a minimal example of what I am planning to do:
import torch
import torch.nn as nn

class ClassificationTransformer(nn.Module):
    def __init__(self, embedding_dim, n_head, depth):
        super(ClassificationTransformer, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim, nhead=n_head, batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=depth
        )

    def forward(self, x):
        # assume x of shape [B, T, F]
        # draw a fresh random [CLS] token on every forward pass
        cls_token = torch.randn((x.shape[0], 1, x.shape[-1]), device=x.device)
        # prepend it along the sequence dimension; tokens is of shape [B, 1+T, F]
        tokens = torch.cat((cls_token, x), dim=1)
        encodings = self.transformer_encoder(tokens)
        # take the encoding at position 0 (the [CLS] slot) as the sequence representation
        representation = encodings[:, 0, :]
        return representation
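For context, I'd use it roughly like this (the shapes and hyperparameters here are just made up for illustration):

model = ClassificationTransformer(embedding_dim=64, n_head=4, depth=2)
x = torch.randn(8, 10, 64)        # batch of 8 sequences, 10 tokens, 64 features
representation = model(x)         # [8, 64], one vector per sequence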
Thanks for helping!