Hello everyone,
I am implementing transfer learning to address catastrophic forgetting in my text generation model. The pre-trained model is based on GPT-2 and was initially trained on financial data (referred to as dataset A). Now I want to fine-tune it on a different dataset, dataset B. However, I have run into an issue: despite training for two hours, there is no improvement and the loss remains constant.
I suspect that either my custom model is not constructed correctly or that there is an issue with my training loop. I would appreciate any insights or suggestions.
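For context, pretrained_model and config in the code below come from Hugging Face transformers. The loading looks roughly like this (a simplified sketch: the LM-head model class and the checkpoint path are placeholders standing in for my actual dataset-A checkpoint):

import torch
from tqdm import tqdm
from transformers import GPT2Config, GPT2LMHeadModel

config = GPT2Config.from_pretrained("gpt2")
# Placeholder path for the GPT-2 model previously trained on dataset A
pretrained_model = GPT2LMHeadModel.from_pretrained("path/to/dataset-A-checkpoint")

The custom model: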
class CustomModel(torch.nn.Module):
    def __init__(self, pretrained_model, config):
        super(CustomModel, self).__init__()
        self.transformer = pretrained_model
        self.config = config
        # First feed-forward block: projects the transformer's vocab-size logits
        # down to the embedding dimension
        self.ffn1 = torch.nn.Sequential(
            torch.nn.Linear(self.config.vocab_size, self.config.n_embd),
            torch.nn.GELU(),
            torch.nn.Linear(self.config.n_embd, self.config.n_embd)
        )
        self.layer_norm1 = torch.nn.LayerNorm(self.config.n_embd)
        # Second feed-forward block with a 2x expansion
        self.ffn2 = torch.nn.Sequential(
            torch.nn.Linear(self.config.n_embd, 2 * self.config.n_embd),
            torch.nn.GELU(),
            torch.nn.Linear(2 * self.config.n_embd, self.config.n_embd)
        )
        self.layer_norm2 = torch.nn.LayerNorm(self.config.n_embd)
        # Final projection back to vocabulary logits
        self.linear = torch.nn.Linear(self.config.n_embd, self.config.vocab_size)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        outputs = self.transformer(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        # Note: this feeds the LM head's logits (vocab_size-dimensional), not the
        # hidden states, into the extra layers
        hidden_states = self.ffn1(outputs.logits)
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states = self.ffn2(hidden_states)
        hidden_states = self.layer_norm2(hidden_states)
        logits = self.linear(hidden_states)
        return logits
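The training loop below also relies on device, criterion, optimizer, and scheduler, set up roughly as follows (again a sketch: the optimizer, learning rate, and scheduler shown here are placeholders for what I actually use):

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CustomModel(pretrained_model, config).to(device)

criterion = torch.nn.CrossEntropyLoss()
# Placeholder hyperparameters; the optimizer is created once, before the loop below
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer)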
The training loop:
epochs = 3
step = 0
model.train()
total_loss = 0.0
for epoch in range(epochs):
    progress_bar = tqdm(train_loader, total=len(train_loader))
    # Gradual unfreezing: the transformer stays frozen for the first two epochs,
    # then all of its parameters become trainable
    if epoch < 2:
        for param in model.transformer.parameters():
            param.requires_grad = False
    else:
        for param in model.transformer.parameters():
            param.requires_grad = True
    for batch in progress_bar:
        input_ids, attention_mask, token_type_ids, targets = batch
        input_ids = input_ids.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(input_ids)  # the attention mask is not passed here
        # Flatten to (batch * seq_len, vocab_size) vs. (batch * seq_len,) for the loss
        loss = criterion(outputs.view(-1, config.vocab_size), targets.view(-1))
        loss.backward()
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
        progress_bar.set_description(f"Epoch {epoch + 1}")
        progress_bar.set_postfix(loss=loss.item())
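In case it helps with diagnosis, here is a quick sanity check I can run after the freezing step to see which parameters are actually trainable:

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable:,} / {total:,}")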