How can I extend TripletMarginLoss to a quadruplet?

TripletMarginLoss measures the relative similarity between three embeddings: a, p and n (i.e. anchor, positive example and negative example, respectively). It penalizes cases where the anchor is not closer to the positive example than to the negative example by at least a margin:

L(a, p, n) = max{ d(a, p) − d(a, n) + margin, 0 }
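For reference, this is how I read the formula in code; computing it by hand with F.pairwise_distance matches the built-in module (with its defaults, including reduction='mean'):

import torch
import torch.nn as nn
import torch.nn.functional as F

anchor = torch.randn(4, 128)
pos = torch.randn(4, 128)
neg = torch.randn(4, 128)
margin = 1.0

# The formula above, applied per sample and averaged over the batch
manual = torch.clamp(
    F.pairwise_distance(anchor, pos, p=2)
    - F.pairwise_distance(anchor, neg, p=2)
    + margin,
    min=0,
).mean()

builtin = nn.TripletMarginLoss(margin=margin, p=2)(anchor, pos, neg)
assert torch.isclose(manual, builtin)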

My goal is to extend this loss function to a quadruplet a, p, n1 and n2. In other words, I want to penalize instances where d(a, n1) and/or d(a, n2) is smaller than d(a, p), where d is a distance measure such as the cosine distance.
In short, I want a and p to be the most similar items in each quadruplet, where p is the output of the model I’m training. Intuitively, I believe this could be achieved as follows:

L(a, p, n1, n2) = max{ d(a, p) − d(a, n1) + margin, 0 } + max{ d(a, p) − d(a, n2) + margin, 0 }
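Written out directly, I picture the quadruplet version roughly like this (quadruplet_margin_loss is just a name I made up for this sketch, not an existing PyTorch function):

import torch
import torch.nn.functional as F

def quadruplet_margin_loss(anchor, positive, neg1, neg2,
                           margin=1.0, p=2, eps=1e-6):
    # One hinge term per negative, as in the formula above
    d_ap = F.pairwise_distance(anchor, positive, p=p, eps=eps)
    d_an1 = F.pairwise_distance(anchor, neg1, p=p, eps=eps)
    d_an2 = F.pairwise_distance(anchor, neg2, p=p, eps=eps)
    loss = (torch.clamp(d_ap - d_an1 + margin, min=0)
            + torch.clamp(d_ap - d_an2 + margin, min=0))
    return loss.mean()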

My question is: how can I adapt the existing TripletMarginLoss to this case? Is it enough to compute TripletMarginLoss twice, once with n1 and once with n2, and then sum the results, as follows?

import torch
import torch.nn as nn

# One TripletMarginLoss per negative (the module is stateless,
# so a single instance could also be reused for both calls)
triplet_loss_1 = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
triplet_loss_2 = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)

# Dummy batch of 100 embeddings of dimension 128
anchor = torch.randn(100, 128, requires_grad=True)
positive = torch.randn(100, 128, requires_grad=True)
negative_1 = torch.randn(100, 128, requires_grad=True)
negative_2 = torch.randn(100, 128, requires_grad=True)

# One triplet term per negative, then sum
output_1 = triplet_loss_1(anchor, positive, negative_1)
output_2 = triplet_loss_2(anchor, positive, negative_2)
total_loss = output_1 + output_2

total_loss.backward()
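If my reasoning is correct, the summed triplet losses should agree with the direct sketch above (same margin, norm and eps):

expected = quadruplet_margin_loss(anchor, positive, negative_1, negative_2,
                                  margin=1.0, p=2, eps=1e-7)
print(torch.isclose(total_loss, expected))  # expect: tensor(True)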