I’m fine-tuning a pre-trained transformer on contrastive pairs using `nn.MSELoss()` over cosine-similarity scores. The loss is updated after each batch, but it does not change from epoch to epoch.
```python
import os
import pickle

import torch
import torch.nn as nn
from tqdm.auto import tqdm
from transformers import (
    AutoModel,
    AutoTokenizer,
    CONFIG_NAME,
    WEIGHTS_NAME,
    get_linear_schedule_with_warmup,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def train_model():
    dataset = read_train_and_val_data()
    tokenizer = AutoTokenizer.from_pretrained(
        'sentence-transformers/stsb-mpnet-base-v2', do_lower_case=True
    )
    # tokenize the left-hand side ('lhs') of each pair
    dataset = dataset.map(
        lambda x: tokenizer(
            x['lhs'], max_length=128, padding='max_length',
            truncation=True, return_tensors='pt'
        ), batched=True, batch_size=2000, num_proc=32
    )
    dataset = dataset.rename_column('input_ids', 'lhs_ids')
    dataset = dataset.rename_column('attention_mask', 'lhs_mask')
    # tokenize the right-hand side ('rhs') of each pair
    dataset = dataset.map(
        lambda x: tokenizer(
            x['rhs'], max_length=128, padding='max_length',
            truncation=True, return_tensors='pt'
        ), batched=True, batch_size=2000, num_proc=32
    )
    dataset = dataset.rename_column('input_ids', 'rhs_ids')
    dataset = dataset.rename_column('attention_mask', 'rhs_mask')
    dataset = dataset.remove_columns(['rhs', 'lhs'])
    dataset.set_format(type='torch', output_all_columns=True)
    # save the tokenized dataset to a pickle file, then reload it
    with open("train_dataset.pk", "wb") as train_file:
        pickle.dump(dataset, train_file)
    with open("train_dataset.pk", "rb") as train_file:
        dataset = pickle.load(train_file)
    # initialize the dataloader
    batch_size = 128
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=16)
    # model = nn.DataParallel(BertModel.from_pretrained('bert-base-uncased'))
    model = AutoModel.from_pretrained('sentence-transformers/stsb-mpnet-base-v2')
    model = nn.DataParallel(model)
    model.to(device)
    cos_sim = torch.nn.CosineSimilarity()
    loss_func = nn.MSELoss()
    # initialize the Adam optimizer
    optim = torch.optim.Adam(model.parameters(), lr=2e-5)
    # set up warmup for the first ~10% of steps
    total_steps = int(len(dataset) / batch_size)
    warmup_steps = int(0.1 * total_steps)
    scheduler = get_linear_schedule_with_warmup(
        optim, num_warmup_steps=warmup_steps,
        num_training_steps=total_steps - warmup_steps
    )
    epochs = 8
    for epoch in range(epochs):
        losses = []
        model.train()
        output_dir = '/home/ubuntu/research/research_sandbox/wanjiru/nlp/gro_source_mapping/gro_nlp/src/models/torch_models/model_epoch_{}/'.format(epoch)
        loop = tqdm(loader, leave=True)
        for batch in loop:
            # batch.to(accelerator.device)
            torch.autograd.set_detect_anomaly(True)
            # zero all gradients on each new step
            optim.zero_grad()
            # prepare the batch and move all tensors to the active device
            anchor_ids = batch['lhs_ids'].to(device)
            anchor_mask = batch['lhs_mask'].to(device)
            pos_ids = batch['rhs_ids'].to(device)
            pos_mask = batch['rhs_mask'].to(device)
            # forward pass through the shared encoder
            a = model(anchor_ids, attention_mask=anchor_mask)
            p = model(pos_ids, attention_mask=pos_mask)
            # mean-pool token embeddings into sentence embeddings
            a = mean_pool(a, anchor_mask)
            p = mean_pool(p, pos_mask)
            # cosine similarity of each pair, regressed onto the gold score
            scores = cos_sim(a, p)
            labels = batch['score'].to(device)
            loss = loss_func(scores, labels)
            losses.append(loss.item())
            # backward pass
            loss.backward()
            optim.step()
            # update the learning-rate scheduler
            scheduler.step()
            # update the tqdm progress bar
            loop.set_description(f'Epoch {epoch}')
            loop.set_postfix(loss=loss.item())
        print(losses)
        # save the model after each epoch
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        output_model_file = os.path.join(output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(output_dir, CONFIG_NAME)
        # model_to_save = model.module if hasattr(model, 'module') else model
        model_to_save = model.module
        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        tokenizer.save_vocabulary(output_dir)
        # evaluate the model after each epoch
        run_eval(output_dir, epoch)


train_model()
```
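
For reference, `mean_pool` isn't shown above; it's the usual masked mean over the encoder's token embeddings. Roughly this sketch (assuming the model output exposes `last_hidden_state` and the mask has shape `(batch, seq_len)`):

```python
def mean_pool(model_output, attention_mask):
    # token embeddings: (batch, seq_len, hidden)
    token_embeddings = model_output.last_hidden_state
    # expand the mask so padded positions contribute nothing
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * mask).sum(dim=1)   # sum of real-token embeddings
    counts = mask.sum(dim=1).clamp(min=1e-9)        # number of real tokens per example
    return summed / counts                          # (batch, hidden) sentence embeddings
```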