Very high, increasing loss in transformer model

I am trying to solve a sequence to sequence problem with a transformer model. The data is derived from a set of crossword puzzles.

The positional encoding and transformer classes are as follows:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 3000):
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1) 
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    def debug(self, x):
        return x.shape, x.size()

    def forward(self, x: Tensor) -> Tensor:

        x = x +[:, :x.size(1), :]
        return self.dropout(x)

class Transformer(nn.Module):

    def __init__(

        self.model_type = "Transformer"
        self.dim_model = dim_model

        self.positional_encoder = PositionalEncoding(
            d_model=dim_model, dropout=dropout_p, max_len=3000
        self.embedding = nn.Embedding.from_pretrained(vec_weights, freeze=False)#nn.Embedding(num_tokens, dim_model)
        self.transformer = nn.Transformer(
            batch_first = batch_first
        self.out = nn.Linear(dim_model, num_tokens)
    def forward(self, src, tgt, tgt_mask=None, src_pad_mask=None, tgt_pad_mask=None):
        src = self.embedding(src)*math.sqrt(self.dim_model)
        tgt = self.embedding(tgt)*math.sqrt(self.dim_model)
        src = self.positional_encoder(src)
        tgt = self.positional_encoder(tgt)
        transformer_out = self.transformer(src, tgt, tgt_mask=tgt_mask, src_key_padding_mask=src_pad_mask, tgt_key_padding_mask=tgt_pad_mask)
        out = self.out(transformer_out)
        return out
    def get_tgt_mask(self, size) -> torch.tensor:
        mask = torch.tril(torch.ones(size, size) == 1) 
        mask = mask.float()
        mask = mask.masked_fill(mask == 0, float('-inf'))
        mask = mask.masked_fill(mask == 1, float(0.0))
        return mask
    def create_pad_mask(self, matrix: torch.tensor, pad_token: int) -> torch.tensor:
        return (matrix == pad_token) 

The input tensors are a source tensor of size N by S, where N is the batch size and S is the source sequence length, and a target tensor of size N by T, where T is the target sequence length. S is about 10 and T is about 5, while the total number of items is about 160,000-200,000, divided into batch sizes of 512. They are torch.IntTensors, with elements in the range from 0 to V, where V is the vocabulary length.

The first layer is an embedding layer that takes the input from N by S to N by S by E, where E is the embedding dimension (300), or to N by T by E in the case of the target. The second layer adds position encoding without changing the shape. Then both tensors are passed through the transformer layer, which outputs an N by T by E tensor. Finally, we pass this output through a linear layer, which produces an N by T by V output, where V is the size of the vocabulary used in the problem. Here V is about 56,697. The most frequent tokens (words) appear about 50-60 times in the target tensor.

The transformer class also contains the functions for implementing the masking matrices.

Then we create the model and run it (this process is wrapped in a function).

device = "cuda"

src_train, src_test =, [int(0.9*len(src_t)), len(src_t)-int(0.9*len(src_t))])
src_train, src_test = src_train[:512], src_test[:512]
tgt_train, tgt_test =, [int(0.9*len(tgt_t)), len(tgt_t)-int(0.9*len(tgt_t))])
tgt_train, tgt_test = tgt_train[:512], tgt_test[:512]
train_data, test_data = list(zip(src_train, tgt_train)), list(zip(src_test, tgt_test))
train, test =,

model = Transformer(num_tokens=ntokens, dim_model=300, num_heads=2, num_encoder_layers=3, num_decoder_layers=3, batch_first = True, dropout_p=0.1).to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0000001)

n_epochs = 50

def train_model(model, optimizer, loss_function, n_epochs):
    for epoch in range(n_epochs):
        print(f"Starting epoch {epoch}")
        for batch, data in enumerate(train):
            x, y = data
            if batch%100 == 0:
                print(f"Batch is {batch}")
            batch += 1

            x, y = torch.tensor(x).to(device), torch.tensor(y).to(device)
            y_input, y_base = y[:, :-1], y[:, 1:]
            y_input, y_base =,

            tgt_mask = model.get_tgt_mask(y_input.shape[1]).to(device)
            pad_token = vocabulary_table[embeddings.key_to_index["/"]]
            src_pad_mask = model.create_pad_mask(x, pad_token).to(device)
            tgt_pad_mask = model.create_pad_mask(y_input, pad_token).to(device)
            z = model(x, y_input, tgt_mask, src_pad_mask, tgt_pad_mask)
            z = z.permute(0, 2, 1).to(device)
            y_base = y_base.long().to(device)
            loss = loss_function(z, y_base).to(device)
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0, norm_type=2)
            loss_value += float(loss)
            if batch%100 == 0:
                print(f"For epoch {epoch}, batch {batch} the cross-entropy loss is {loss_value}")
            #Free GPU memory.
            del z
            del x
            del y
            del y_input
            del y_base
            del loss
    return model.parameters(), loss_value

Basically, we split the data into test and training sets and use an SGD optimizer and cross-entropy loss. We create a masking matrix for the padding for both the target and source tensors, and a masking matrix for unseen elements for the target tensor. We then do the usual gradient update steps. Right now, there is no validation loop, because I cannot even get the training loss to decrease.

The loss is very high, reaching more than 1000 after 100 batches. More concerningly, the loss also increases rapidly during training, rather than decreasing. In the code that I included, I tried clipping the gradients, lowering the learning rate, and using a much smaller sample to debug the code.

Does anyone know what could be causing this behavior?