I am training a BERT model together with a custom loss function that has a trainable parameter.
The system builds sentence representations from BERT by mean-pooling the token embeddings and computes the cosine similarity between sentence pairs; the loss function then linearly interpolates between that cosine similarity and precomputed TF-IDF similarity values, where the interpolation weight is supposed to be trainable.
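For concreteness, torch.lerp(a, b, w) computes a + w * (b - a), so the loss sees a convex combination of the two similarity scores (toy values for illustration):

import torch

sim = torch.tensor([0.8, -0.2])       # cosine similarity from the BERT embeddings
tfidf_sim = torch.tensor([0.6, 0.1])  # the precomputed TF-IDF similarity values
w = torch.tensor([0.5])               # the trainable interpolation weight
torch.lerp(sim, tfidf_sim, w)         # tensor([ 0.7000, -0.0500])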
When I use nn.CosineEmbeddingLoss instead of my custom loss function, training runs without a problem. With the setup above, however, memory usage grows steadily during training until it goes out of memory partway through (around the 400th iteration).
I suspect that some tensor is not being freed properly somewhere, but I cannot figure out where.
Could you point out my mistake or suggest what might be wrong?
I would greatly appreciate your comments. Thank you!
Here is the train loop:
import itertools

import torch
from torch import nn

# logger, args, device, optimizer, loss_fn, art_to_hier, and _mean_pooling
# are defined elsewhere in the script

def train(dataloader, model):
    Loss_main = list()
    Loss = list()
    logger.info("starting train")
    model.train()
    cos_sim = nn.CosineSimilarity(dim=1, eps=1e-6)
    for i, batch in enumerate(dataloader):
        # unpack the batch: query/article input ids and attention masks,
        # binary labels, article ids, and precomputed TF-IDF similarities
        que_iis = batch[0].to(device)
        que_ams = batch[1].to(device)
        art_iis = batch[2].to(device)
        art_ams = batch[3].to(device)
        labels = torch.where(batch[4] == 0, -1, 1).to(device)
        article_ids = batch[5]
        section_ids = [art_to_hier[a.item()][-1] for a in article_ids]
        tfidf_sim = batch[6].to(device)

        # encode queries and articles with the shared BERT encoder,
        # then mean-pool the token embeddings into sentence embeddings
        query_out = model(input_ids=que_iis, attention_mask=que_ams).last_hidden_state
        article_out = model(input_ids=art_iis, attention_mask=art_ams).last_hidden_state
        query_emb = _mean_pooling(query_out, que_ams)
        article_emb = _mean_pooling(article_out, art_ams)

        sim = cos_sim(article_emb, query_emb)
        loss = loss_fn(sim, tfidf_sim, labels)
        loss = torch.mean(loss)
        Loss_main.append(loss.item())

        # gradient accumulation: scale the loss and step every accumulation_size batches
        loss = loss / args.accumulation_size
        loss.backward()
        if (i + 1) % args.accumulation_size == 0:
            if args.grad_clip > 0.:
                torch.nn.utils.clip_grad_norm_(
                    itertools.chain(model.parameters(), loss_fn.parameters()),
                    max_norm=args.grad_clip, norm_type=2
                )
            optimizer.step()
            optimizer.zero_grad()
            Loss.append(loss.item())
    Loss = sum(Loss) / len(Loss)
    Loss_main = sum(Loss_main) / len(Loss_main)
    return Loss, Loss_main, loss_fn.lerp_weight
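For completeness, _mean_pooling is the usual masked mean over token embeddings, along these lines:

def _mean_pooling(last_hidden_state, attention_mask):
    # standard masked mean: zero out padding tokens, then average over real tokens
    mask = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
    summed = torch.sum(last_hidden_state * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / counts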
The loss function is defined as follows:
class CosSimLossWithLerp(nn.Module):
    def __init__(self):
        super().__init__()
        # trainable interpolation weight, initialized to 0.5
        self.lerp_weight = nn.Parameter(torch.tensor([0.5]))

    def forward(self, sim, tfidf_sim, target):
        # blend the cosine similarity with the TF-IDF similarity
        sim = torch.lerp(sim, tfidf_sim, self.lerp_weight)
        # hinge-style loss: 1 - sim for positive pairs, max(0, sim) for negative pairs
        loss = [
            1 - sim[i] if target[i] == 1
            else max(torch.tensor(0., requires_grad=True).to(device), sim[i])
            for i in range(len(sim))
        ]
        loss = torch.stack(loss)
        loss = torch.mean(loss)
        return loss
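For reference, I believe the list comprehension computes the same thing as this vectorized form (an untested sketch; vectorized_loss is just for illustration, I have not swapped it in):

def vectorized_loss(sim, tfidf_sim, target, lerp_weight):
    # same blend as in forward()
    sim = torch.lerp(sim, tfidf_sim, lerp_weight)
    # 1 - sim where target == 1, max(0, sim) otherwise, then the batch mean
    return torch.where(target == 1, 1 - sim, torch.clamp(sim, min=0.)).mean()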
The loss function is instantiated outside the train loop:
loss_fn = CosSimLossWithLerp()
loss_fn.to(device)
I included the loss function parameters in the optimizer:
optimizer = torch.optim.Adam(itertools.chain(model.parameters(), loss_fn.parameters()), lr=args.lr1)