Transformer layer weights not updating

Hi:

I am trying to classify simple sequence data.
Examples:
([‘bbRcabbSab’,
‘abRbbccSca’,
‘baabRcaaaS’,
‘abRbabaSaa’,
‘acRcbacSbc’,
‘bcRbaacSbc’,
‘RbbaaabaSc’,
‘abRacabSba’,
‘RabcbSbcba’,
‘ccRbcaaSba’],
[1, 1, 0, 1, 1, 1, 0, 1, 0, 1])
If R is in position 2 and S is in position len(s)-3 then the target is 1 and 0 otherwise.

I am trying to use a transformer for this. The model I am using is:

class PositionalEncoding(nn.Module):
    def __init__(self, dim_model, dropout_p, max_len=5000):
        super(PositionalEncoding, self).__init__()
        # Modified version from: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
        # max_len determines how far the position can have an effect on a token (window)
        
        self.dropout = nn.Dropout(dropout_p)
        
        # Encoding - From formula
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1) # 0, 1, 2, 3, 4, 5
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model) # 1000^(2i/dim_model)
        
        # PE(pos, 2i) = sin(pos/1000^(2i/dim_model))
        pos_encoding[:, 0::2] = torch.sin(positions_list * division_term)
        
        # PE(pos, 2i + 1) = cos(pos/1000^(2i/dim_model))
        pos_encoding[:, 1::2] = torch.cos(positions_list * division_term)
        
        # Saving buffer (same as parameter without gradients needed)
        pos_encoding = pos_encoding.unsqueeze(0)
    
        # register buffer so it is not included in model paramters and no gradient is computed
        self.register_buffer("pos_encoding",pos_encoding)
        
    def forward(self, token_embedding: torch.tensor) -> torch.tensor:
        # Residual connection + pos encoding
        pos_encoded = self.pos_encoding[:, :token_embedding.size(1), :]
        return self.dropout(token_embedding + pos_encoded)

class TransformerLinearNet(nn.Module):
    def __init__(
        self,
        num_tokens,
        dim_model,
        dim_hidden,
        num_heads,
        num_encoder_layers,
        dropout_p,
        max_len
    ):
        super(TransformerLinearNet, self).__init__()

        self.dim_hidden = dim_hidden
        self.pos_encoder = PositionalEncoding(dim_model, dropout)
        encoder_layers = TransformerEncoderLayer(dim_model, num_heads, dim_hidden, dropout_p)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_encoder_layers)
        self.encoder = nn.Embedding(num_tokens, dim_model)
        self.d_model = dim_model
        self.pooler = nn.Linear(dim_model, dim_model)
        self.activation = nn.Tanh()
        self.fc = nn.Linear(dim_model, 1)
        self.sigmoid = nn.Sigmoid() 
        self.dropout = dropout_p
        
        # CLS and SEP tokens
        cls_embed = torch.randn(1,1,dim_model,requires_grad=False) * math.sqrt(self.d_model)
        sep_embed = torch.randn(1,1,dim_model,requires_grad=False) * math.sqrt(self.d_model)
        self.register_buffer("cls_embed",cls_embed)
        self.register_buffer("sep_embed",sep_embed)
        
        # self.init_weights()

    def init_weights(self) -> None:
        torch.nn.init.xavier_uniform(self.encoder.weight)
        torch.nn.init.xavier_uniform(self.fc.bias)
        torch.nn.init.xavier_uniform(self.fc.weight)
        torch.nn.init.xavier_uniform(self.pooler.bias)
        torch.nn.init.xavier_uniform(self.pooler.weight)
        
    
    def forward(self, src: Tensor, pad_mask: Tensor) -> Tensor:
        """
        Args:
            src: Tensor, shape [seq_len, batch_size]
            src_mask: Tensor, shape [seq_len, seq_len]

        Returns:
            output Tensor of shape [seq_len, batch_size, ntoken]
        """
        batch_size = src.shape[0]
        t = self.encoder(src) * math.sqrt(self.d_model)
        # Add the fixed cls_token and sep_token
        t = torch.cat((self.cls_embed.repeat(batch_size,1,1),t,self.sep_embed.repeat(batch_size,1,1)), dim=1)
        t = self.pos_encoder(t)
        if pad_mask is None:
            out = self.transformer_encoder(t)
        else:
            out = self.transformer_encoder(t, src_key_padding_mask=pad_mask)
        # Feed the CLS encoding to the decoder layer for classification
        out = self.pooler(out[torch.arange(batch_size), 0,:].squeeze())
        out = self.activation(out)
        out = self.fc(out)
        out = self.sigmoid(out)
        out = out.view( batch_size, -1 )
        out = out[:,-1]
        
        return out

Training is:

emsize = 10 # embedding dimension
d_hid = 10 # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 5  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 1  # number of heads in nn.MultiheadAttention
dropout = 0.5  # dropout probability
model = TransformerLinearNet(ntokens, emsize, nhead, d_hid, nlayers, dropout, pad_length).to(device)

model.train() 

#### compute the accuracy of the model predictions
def accuracy(outputs, targets):
    labels = torch.zeros(len(outputs)).to(device)
    ones_index = torch.where(outputs > 0.5)[0]
    labels[ones_index] = 1.0
    return torch.sum(labels==targets.float())/float(len(labels))

lr = 0.001
criterion = nn.BCELoss()
optimizer = torch.optim.Adam( model.parameters(), lr=lr )

### training the model with TransformerLinearNet

epochs = 10
history = {'loss':[], 'val_loss':[], 'accuracy':[], 'val_accuracy':[]}
for i in range( epochs ):
    
    running_loss = 0
    num_batches = 0 
    training_accuracy = 0
    val_accuracy = 0
    
    for inputs, targets in train_loader: 
        inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True) 
        optimizer.zero_grad()
        
        pad_mask = create_pad_mask(inputs, pad_token=0)
        pad_mask = None
        
        output = model( inputs, pad_mask )
        loss = criterion(output, targets.float())
        loss.backward()
        running_loss += loss
        
        optimizer.step()
        
        # compute training accuracy
        training_accuracy += accuracy(output, targets.float())
    
        num_batches += 1 
        
    running_loss = running_loss.item()/num_batches  
    training_accuracy = training_accuracy.item()/num_batches
    
    val_loss = 0
    num_batches = 0 
    
    with torch.no_grad():
        # compute test loss and accuracy
        for test_inputs, test_labels in test_loader:
            test_inputs, test_labels = test_inputs.to(device, non_blocking=True), test_labels.to(device, non_blocking=True)

            num_batches += 1

            pad_mask = create_pad_mask(test_inputs, pad_token=0)    
            pad_mask = None

            test_outputs = model(test_inputs, pad_mask)

            val_loss += criterion(test_outputs, test_labels.float())
            val_accuracy += accuracy(test_outputs, test_labels.float())

        val_accuracy = val_accuracy.item()/num_batches 
        val_loss = val_loss.item()/num_batches
        
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_accuracy)
    
    history['loss'].append(running_loss)
    history['accuracy'].append(training_accuracy)
    
    
    print('epoch',i,'loss', "{:.2f}".format(running_loss),'training accuracy', "{:.2f}".format(training_accuracy), 'val accuracy', "{:.2f}".format(val_accuracy))
    

But I consistently get the following result where the loss is the same.

epoch 0 loss 0.69 training accuracy 0.49 val accuracy 0.51
epoch 1 loss 0.69 training accuracy 0.50 val accuracy 0.50
epoch 2 loss 0.69 training accuracy 0.50 val accuracy 0.50
epoch 3 loss 0.69 training accuracy 0.49 val accuracy 0.51
epoch 4 loss 0.69 training accuracy 0.51 val accuracy 0.51
epoch 5 loss 0.69 training accuracy 0.50 val accuracy 0.49
epoch 6 loss 0.69 training accuracy 0.50 val accuracy 0.50
epoch 7 loss 0.69 training accuracy 0.50 val accuracy 0.51
epoch 8 loss 0.69 training accuracy 0.50 val accuracy 0.50
epoch 9 loss 0.69 training accuracy 0.50 val accuracy 0.51

Any insight into what I am doing wrong would be most appreciated.

Thank you!

The problem was that I had swapped the parameters nhead and dim_head when passing to the TransformerLinearNet instantiation. Also, I just used the mean of the transformer encoder outputs as it was not working with the CLS token.

Here is the working model code.

import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import torch.nn.functional as F 
import torch.optim as optim

from torch.utils.data import TensorDataset, DataLoader 
class TransformerLinearNet(nn.Module):
    def __init__(
        self,
        num_tokens,
        dim_model,
        dim_hidden,
        num_heads,
        num_encoder_layers,
        dropout_p,
        max_len
    ):
        super(TransformerLinearNet, self).__init__()

        self.dim_hidden = dim_hidden
        self.pos_encoder = PositionalEncoding(dim_model, dropout)
        encoder_layers = TransformerEncoderLayer(dim_model, num_heads, dim_hidden, dropout_p)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_encoder_layers)
        self.encoder = nn.Embedding(num_tokens+1, dim_model, padding_idx=0)
        self.d_model = dim_model
        self.fc = nn.Linear(dim_model, 1)
        self.sigmoid = nn.Sigmoid() 
        self.dropout = dropout_p
                
        self.init_weights()

    def init_weights(self) -> None:
        torch.nn.init.xavier_uniform_(self.encoder.weight)
        torch.nn.init.xavier_uniform_(self.fc.weight)
        
    
    def forward(self, src: Tensor, pad_mask: Tensor=None) -> Tensor:
        """
        Args:
            src: Tensor, shape [seq_len, batch_size]
            src_mask: Tensor, shape [seq_len, seq_len]

        Returns:
            output Tensor of shape [seq_len, batch_size, ntoken]
        """
        batch_size = src.shape[0]
        t = self.encoder(src) * math.sqrt(self.d_model)
        t = self.pos_encoder(t)
        if pad_mask is None:
            out = self.transformer_encoder(t)
        else:
            out = self.transformer_encoder(t, src_key_padding_mask=pad_mask)
            
        # Use the mean of the encodings as the input to the classifier
        out = out.mean(dim=1)
        out = self.fc(out)
        out = self.sigmoid(out)
        out = out.view( batch_size, -1 )
        out = out[:,-1]
        
        return out

The model above works for the examples provided.

Here is the code to generate the data examples for those interested.

# Add R to the second position from the start and S to the second position from the last
def generate_RS_pos_strings(n_strings, string_length, charset):
    """
    n_strings : number of strings to generate
    string_length : length of the string to be generated
    charset : set of characters to use
    """
    char_list = list( charset )
    string_list = []
    label_list = []
    
    
    for s_counter in range(0,n_strings):
        label = random.randint(0,1)
        this_string_list = []
        
        for i in range(0,string_length):
            this_string_list.append( random.choice( char_list ) )
        if label: 
            this_string_list[2] = 'R'
            this_string_list[string_length-3] = 'S'
        else:
            indices = list(range(string_length))
            indices.remove(2)
            indices.remove(string_length-3)
            index_1 = random.choice(indices)           
            indices.remove(index_1)
            index_2 = random.choice(indices)
            if index_1 < index_2:
                this_string_list[index_1] = 'R'
                this_string_list[index_2] = 'S'
            else:
                this_string_list[index_2] = 'R'
                this_string_list[index_1] = 'S'
        string_list.append( ''.join( this_string_list ) )
        label_list.append( label )
    return string_list, label_list
 
pad_length=10
max_length=pad_length

training_strings, y_train = generate_RS_pos_strings(data_size, max_length, 'abc')
test_strings, y_test = generate_RS_pos_strings(data_size, max_length, 'abc')

# This constructs and returns two dicts, that can be used to map the characters that are in the string
# to ints, and those ints back to the characters

def char_to_integers( mystring ):
    charlist = list( set( list( mystring )))
    nums = range(1,len(charlist)+1)
    c2ndict = dict()
    n2cdict = dict()
    for c,n in zip(charlist,nums):
        c2ndict[c]=n
        n2cdict[n]=c
    return c2ndict, n2cdict

# this inefficiently ensures that we have complete character codinng and decoding
# dictionaries by scanning through the entire training set
c2i, i2c = char_to_integers(''.join(training_strings))
ntokens = len( c2i ) + 1 # we leave 0 as the padding symbol

def string_to_int_vec( s, pad_length, code_dict):
    """
    Converts a string to a vector of ints, using a character-encoding dictionary
    
    s : the string to convert
    padlength : the length to pad the string to, with initial zeros
    code : dict giving the conversion from chars to integers
    """
    slen = len(s)
    assert slen <= pad_length
    v = np.zeros([pad_length])
    startx = pad_length - slen 
    stringlist = list(s)
    for i in range(0,slen):
        v[startx + i] = code_dict[stringlist[i]]
    return v

def strings_to_nparray(strings, pad_length, char_to_integers):
    """
    Converts a list of strings to an numpy array, which can be used as training
    or testing data
    
    strings : a list of strings
    maxlen  : an integer, greater than or equal to the max length of any of the strings
    
    This function converts a list of n strings into a n x maxlen numpy array, containing
    the coded strings
    """
    mat = np.zeros([len(strings),pad_length])
    for i in range(0,len(strings)):
        mat[i,:] = string_to_int_vec( strings[i], pad_length, char_to_integers )
    return mat

x_train = strings_to_nparray(training_strings, pad_length, c2i)
x_test = strings_to_nparray(test_strings, pad_length, c2i)

train_data = TensorDataset( torch.LongTensor( x_train ), torch.LongTensor( y_train ))
test_data = TensorDataset( torch.LongTensor( x_test ), torch.LongTensor( y_test ))

batch_size = 64

train_loader = DataLoader( train_data, shuffle=True, batch_size= batch_size )
test_loader = DataLoader( test_data, shuffle=True, batch_size = batch_size ) # we don't really need to shuffle this

The code still does not work for the following bracket matching examples. If anyone is adventurous to try it out. I was told that matching brackets is quite a tricky problem for transformers as it has high circuit complexity.

Bracket matching examples.
([‘()()()()()’, ‘()()))()()’, ‘())()())()’, ‘(((())))))’, ‘()()))()()’],
[1, 0, 0, 0, 0])