Training parameters in the loss function as well as in the model

I am training a BERT model together with a loss function that has its own trainable parameter.
The system builds sentence representations with BERT and computes the cosine similarity between sentence pairs. The loss function then linearly interpolates between these cosine similarity values and some constant values (TF-IDF similarities supplied with each batch), and the interpolation weight is meant to be trainable.
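The interpolation step can be looked at in isolation. `torch.lerp(start, end, weight)` computes `start + weight * (end - start)`, so when `weight` is an `nn.Parameter` the gradient reaches it through `end - start`. A minimal sketch with made-up similarity values (the numbers are illustrative, not from my data):

```python
import torch
import torch.nn as nn

# Trainable interpolation weight, initialised to 0.5 as in my loss module
lerp_weight = nn.Parameter(torch.tensor([0.5]))

# Illustrative values: model cosine similarities and constant TF-IDF similarities
sim = torch.tensor([0.9, 0.2, -0.1], requires_grad=True)
tfidf_sim = torch.tensor([0.5, 0.5, 0.5])

# torch.lerp(start, end, weight) == start + weight * (end - start)
mixed = torch.lerp(sim, tfidf_sim, lerp_weight)
manual = sim + lerp_weight * (tfidf_sim - sim)

# Any scalar built from `mixed` back-propagates into lerp_weight
mixed.sum().backward()
print(mixed.detach(), manual.detach())
print(lerp_weight.grad)  # d(sum)/dw = sum(tfidf_sim - sim)
```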

When I use `nn.CosineEmbeddingLoss` instead of my custom loss function, training runs without a problem. With the custom loss, however, memory usage keeps growing and training fails with an out-of-memory error partway through (around the 400th iteration).

I suspect that memory is not being freed properly somewhere, but I cannot figure out where.
Could you point out my mistake or suggest what might be wrong?
I would greatly appreciate any comments. Thank you!

Here is the training loop:

```python
import itertools

import torch
import torch.nn as nn

# device, args, loss_fn, optimizer, art_to_hier and _mean_pooling are defined elsewhere in the script


def train(dataloader, model):
    Loss_main = list()
    Loss = list()
    print("starting train")
    cos_sim = nn.CosineSimilarity(dim=1, eps=1e-6)
    for i, batch in enumerate(dataloader):
        # query / article token ids and attention masks
        que_iis = batch[0].to(device)
        que_ams = batch[1].to(device)
        art_iis = batch[2].to(device)
        art_ams = batch[3].to(device)
        # label 0 -> -1, any other label -> 1
        labels = torch.where(batch[4]==0, -1, 1).to(device)
        article_ids = batch[5]
        section_ids = [art_to_hier[a.item()][-1] for a in article_ids]
        tfidf_sim = batch[6].to(device)
        # mean-pooled BERT sentence embeddings
        query_out = model(input_ids=que_iis, attention_mask=que_ams).last_hidden_state
        article_out = model(input_ids=art_iis, attention_mask=art_ams).last_hidden_state
        query_emb = _mean_pooling(query_out, que_ams)
        article_emb = _mean_pooling(article_out, art_ams)
        sim = cos_sim(article_emb, query_emb)
        loss = loss_fn(sim, tfidf_sim, labels)
        loss = torch.mean(loss)
        # scale for gradient accumulation; gradients accumulate every iteration
        loss = loss / args.accumulation_size
        loss.backward()
        # (the code that appends to Loss and Loss_main is not shown in this post)
        if (i + 1) % args.accumulation_size == 0:
            if args.grad_clip > 0.:
                torch.nn.utils.clip_grad_norm_(
                    itertools.chain(model.parameters(), loss_fn.parameters()), max_norm=args.grad_clip, norm_type=2
                )
            optimizer.step()
            optimizer.zero_grad()

    Loss = sum(Loss) / len(Loss)
    Loss_main = sum(Loss_main) / len(Loss_main)
    return Loss, Loss_main, loss_fn.lerp_weight
```
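Since the lines that fill `Loss` and `Loss_main` are not shown, one thing worth ruling out is the pattern the PyTorch FAQ warns about ("don't accumulate history across your training loop"): storing loss tensors that still carry autograd history, instead of plain numbers. A minimal sketch of the safer logging pattern with gradient accumulation, using a hypothetical tiny linear model and random data in place of BERT:

```python
import torch
import torch.nn as nn

# Hypothetical tiny model and data, used only to illustrate the logging pattern
torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
accumulation_size = 2  # stands in for args.accumulation_size

history = []  # plain Python floats, so no autograd graph is referenced
for i in range(6):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y) / accumulation_size
    loss.backward()                 # accumulate gradients every step
    history.append(loss.item())     # .item() returns a detached Python float
    if (i + 1) % accumulation_size == 0:
        optimizer.step()
        optimizer.zero_grad()

mean_loss = sum(history) / len(history)
print(mean_loss)
```

Appending `loss` itself (or summing into a tensor with `total += loss`) would keep a reference to each step's autograd history, whereas `.item()` keeps only the number.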

The loss function is defined as follows:

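```python
import torch
import torch.nn as nn


class CosSimLossWithLerp(nn.Module):
    def __init__(self):
        super().__init__()  # must run before an nn.Parameter can be registered
        # trainable interpolation weight, initialised to 0.5
        self.lerp_weight = nn.Parameter(torch.tensor([0.5]))

    def forward(self, sim, tfidf_sim, target):
        # interpolate between the model's cosine similarity and the TF-IDF similarity
        sim = torch.lerp(sim, tfidf_sim, self.lerp_weight)
        # per-pair loss: 1 - sim for positive pairs, max(0, sim) for negative pairs
        loss = [1 - sim[i] if target[i] == 1 else max(torch.tensor(0., requires_grad=True).to(device), sim[i]) for i in range(len(sim))]
        loss = torch.stack(loss)
        loss = torch.mean(loss)
        return loss
```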
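The per-element list comprehension in `forward` can also be written with tensor operations. The class below is a hypothetical vectorised variant with the same maths, not the code I am actually running: `torch.where` selects `1 - sim` for positive pairs and `sim.clamp(min=0)` (that is, `max(0, sim)`) for negative pairs, so no new `requires_grad` tensor is created per element. With the weight fixed at 0 it reduces to the per-sample formula of `nn.CosineEmbeddingLoss` with `margin=0`, which is the loss that trains without problems:

```python
import torch
import torch.nn as nn


class VectorisedCosSimLossWithLerp(nn.Module):
    """Hypothetical vectorised variant of CosSimLossWithLerp (same maths, no Python loop)."""

    def __init__(self, init_weight: float = 0.5):
        super().__init__()
        self.lerp_weight = nn.Parameter(torch.tensor([init_weight]))

    def forward(self, sim, tfidf_sim, target):
        # interpolate between model similarity and the constant TF-IDF similarity
        sim = torch.lerp(sim, tfidf_sim, self.lerp_weight)
        # target == 1 -> 1 - sim ; target == -1 -> max(0, sim)
        loss = torch.where(target == 1, 1 - sim, sim.clamp(min=0))
        return loss.mean()


# Usage: with the weight at 0 the result should match nn.CosineEmbeddingLoss(margin=0)
torch.manual_seed(0)
a, q = torch.randn(5, 8), torch.randn(5, 8)
target = torch.tensor([1.0, -1.0, 1.0, -1.0, -1.0])
sim = nn.functional.cosine_similarity(a, q, dim=1)

loss_fn = VectorisedCosSimLossWithLerp(init_weight=0.0)
ours = loss_fn(sim, torch.zeros_like(sim), target)
ref = nn.CosineEmbeddingLoss(margin=0.0)(a, q, target)
print(ours.item(), ref.item())
```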

It is instantiated outside the training loop:

```python
loss_fn = CosSimLossWithLerp()
```

I included the loss function parameters in the optimizer:

```python
optimizer = torch.optim.Adam(itertools.chain(model.parameters(), loss_fn.parameters()), lr=args.lr1)
```
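Chaining both parameter iterators into one optimizer is a valid way to train the interpolation weight together with the model: the optimizer turns the iterator into a list when it is constructed, while the clipping call in the loop builds a fresh `itertools.chain` each time it runs. The loss module's parameter is created on the CPU, so it also needs to sit on the same device as the similarities (e.g. `loss_fn.to(device)`). A small check, using a hypothetical `nn.Linear` stand-in for BERT and a minimal `LerpLoss` module, that `lerp_weight` actually changes after one step:

```python
import itertools

import torch
import torch.nn as nn


class LerpLoss(nn.Module):
    """Minimal stand-in for CosSimLossWithLerp: one trainable interpolation weight."""

    def __init__(self):
        super().__init__()
        self.lerp_weight = nn.Parameter(torch.tensor([0.5]))

    def forward(self, sim, const_sim, target):
        mixed = torch.lerp(sim, const_sim, self.lerp_weight)
        return torch.where(target == 1, 1 - mixed, mixed.clamp(min=0)).mean()


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
model = nn.Linear(8, 8).to(device)   # hypothetical stand-in for the BERT encoder
loss_fn = LerpLoss().to(device)      # keep the loss parameter on the same device

optimizer = torch.optim.Adam(itertools.chain(model.parameters(), loss_fn.parameters()), lr=1e-2)

before = loss_fn.lerp_weight.detach().clone()
x1, x2 = torch.randn(4, 8, device=device), torch.randn(4, 8, device=device)
sim = nn.functional.cosine_similarity(model(x1), model(x2), dim=1)
target = torch.tensor([1.0, 1.0, -1.0, 1.0], device=device)
loss = loss_fn(sim, torch.full_like(sim, 0.3), target)
loss.backward()
optimizer.step()
print(before.item(), loss_fn.lerp_weight.item())
```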