Adding layers to a text generation model (transfer learning)

Hello everyone,

I am currently working on using transfer learning to address catastrophic forgetting in my text generation model. The pre-trained model is based on GPT2 and was initially trained on financial data (referred to as dataset A). Now I want to fine-tune it on a different dataset, dataset B. However, I have run into a problem: even after training for 2 hours there is no improvement, and the loss stays essentially constant.
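
For context, the pre-trained model and its config are loaded roughly like this (the checkpoint path is just a placeholder for my GPT2 model trained on dataset A):

from transformers import GPT2LMHeadModel

# Placeholder path: the GPT2 checkpoint previously trained on dataset A
pretrained_model = GPT2LMHeadModel.from_pretrained("path/to/gpt2-dataset-A")
config = pretrained_model.config  # provides the vocab_size and n_embd used below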

I suspect that either my custom model is not constructed correctly or there is an issue with my training loop. I would appreciate any insights or suggestions you may have.

import torch


class CustomModel(torch.nn.Module):
  def __init__(self, pretrained_model, config):
    super(CustomModel, self).__init__()
    # The pre-trained GPT2 language model (with LM head) that is being extended
    self.transformer = pretrained_model
    self.config = config

    # First block: project the GPT2 vocabulary logits back down to the hidden size
    self.ffn1 = torch.nn.Sequential(
        torch.nn.Linear(self.config.vocab_size, self.config.n_embd),
        torch.nn.GELU(),
        torch.nn.Linear(self.config.n_embd, self.config.n_embd)
    )
    self.layer_norm1 = torch.nn.LayerNorm(self.config.n_embd)

    # Second block: standard feed-forward layer with a 2x expansion
    self.ffn2 = torch.nn.Sequential(
        torch.nn.Linear(self.config.n_embd, 2*self.config.n_embd),
        torch.nn.GELU(),
        torch.nn.Linear(2*self.config.n_embd, self.config.n_embd)
    )
    self.layer_norm2 = torch.nn.LayerNorm(self.config.n_embd)

    # Output head: project hidden states back to vocabulary logits
    self.Linear = torch.nn.Linear(self.config.n_embd, self.config.vocab_size)

  def forward(self, input_ids, attention_mask=None, token_type_ids=None):
    outputs = self.transformer(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

    # outputs.logits has shape (batch, seq_len, vocab_size); the new layers are stacked on top of it
    hidden_states = self.ffn1(outputs.logits)
    hidden_states = self.layer_norm1(hidden_states)

    hidden_states = self.ffn2(hidden_states)
    hidden_states = self.layer_norm2(hidden_states)

    logits = self.Linear(hidden_states)

    return logits
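
For reference, the custom model is instantiated on top of the pre-trained GPT2 roughly like this (device selection simplified); a dummy forward pass like the one below should produce logits of shape (batch, seq_len, vocab_size):

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CustomModel(pretrained_model, config).to(device)

# Quick sanity check on the output shape with a dummy batch
dummy_ids = torch.randint(0, config.vocab_size, (2, 16)).to(device)
with torch.no_grad():
  print(model(dummy_ids).shape)  # expected: torch.Size([2, 16, vocab_size])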

The training loop:

epochs = 3
step = 0
model.train()
total_loss = 0.

for epoch in range(epochs):
  # Progress bar over the training batches for this epoch
  progress_bar = tqdm(train_loader, total=len(train_loader))

  # Gradual unfreezing: keep the GPT2 backbone frozen for the first two epochs
  # so that only the new layers are trained, then unfreeze everything
  if epoch < 2:
    for param in model.transformer.parameters():
      param.requires_grad = False
  else:
    for param in model.transformer.parameters():
      param.requires_grad = True

  for batch in progress_bar:
    input_ids, attention_mask, token_type_ids, targets = batch

    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    targets = targets.to(device)

    optimizer.zero_grad()
    outputs = model(input_ids, attention_mask=attention_mask)
    loss = criterion(outputs.view(-1, config.vocab_size), targets.view(-1))
    loss.backward()
    optimizer.step()
    scheduler.step()
    total_loss += loss.item()
    progress_bar.set_description(f"Epoch {epoch+1}")
    progress_bar.set_postfix(loss=loss.item())
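
For completeness, the parts referenced above but not shown (data loader, criterion, optimizer, scheduler) are set up roughly like this; train_dataset is a placeholder for my tokenized dataset B and the hyperparameters are simplified:

from tqdm import tqdm
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader

# Loader over dataset B; each batch is (input_ids, attention_mask, token_type_ids, targets)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

criterion = torch.nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=5e-5)

# Linear decay over all training steps; scheduler.step() is called once per batch in the loop
num_training_steps = epochs * len(train_loader)
scheduler = LambdaLR(optimizer, lambda step: max(0.0, 1.0 - step / num_training_steps))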