I have seen some triplet loss implementations in PyTorch that call model.forward separately on the anchor, positive, and negative images, then compute the triplet loss, and finally call loss.backward and optimizer.step, something like this:
anchor_embed = model.forward(anchor_images)  # first forward pass
pos_embed = model.forward(pos_images)        # second forward pass
neg_embed = model.forward(neg_images)        # third forward pass
loss = triplet_loss.forward(anchor_embed, pos_embed, neg_embed)
optimizer.zero_grad()
loss.backward()
optimizer.step()
I think these implementations are not correct, since backpropagation would then only consider the outputs of the last call, neg_embed = model.forward(neg_images), and disregard the anchor and positive images. As far as I know, gradients in PyTorch are accumulated, but calling model.forward again overwrites the previous computation (which would make sense).
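To test this, I suppose I could run a toy check like the one below (a hypothetical model and random data, purely illustrative) and inspect whether the gradient reflects both forward passes or only the last one:

import torch
import torch.nn as nn

# Toy check (illustrative only): do two separate forward passes both
# contribute to the gradient, or does the second overwrite the first?
model = nn.Linear(4, 2)
out1 = model(torch.randn(3, 4))
out2 = model(torch.randn(3, 4))
(out1.sum() + out2.sum()).backward()
print(model.weight.grad)  # inspect whether both passes contributed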
If so, the right implementation should concatenate the anchor, positive, and negative images, perform a single model.forward, then compute the loss, and finally call loss.backward and optimizer.step.
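A minimal sketch of what I mean, assuming all three batches have the same size and that model, triplet_loss, and optimizer are defined as in the snippet above:

import torch

# One forward pass over all three batches at once (assumes equal batch sizes).
combined = torch.cat([anchor_images, pos_images, neg_images], dim=0)
combined_embed = model.forward(combined)

# Split the embeddings back into the three groups.
anchor_embed, pos_embed, neg_embed = torch.chunk(combined_embed, 3, dim=0)

loss = triplet_loss.forward(anchor_embed, pos_embed, neg_embed)
optimizer.zero_grad()
loss.backward()
optimizer.step()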
Is this right? Or am I missing something?
Thanks.