I am implementing a triplet network in PyTorch where I use 3 forward passes and a single backward pass, similar to what is described here. My model is trained by optimizing the triplet loss. After running the code, everything works as expected.

However, I am still confused about how backprop is computed in this implementation. Specifically, since PyTorch accumulates the derivatives, the gradients of the triplet loss w.r.t. the last linear layer (embedding) (shown here) always add up to zero. Of course, this cannot be true, because the network eventually learns meaningful embeddings.
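For reference, here is the derivation behind the "adds up to zero" observation, assuming the common squared-Euclidean form of the triplet loss with margin $\alpha$ (an assumption: the linked formulas are not reproduced in this thread). For an active triplet (hinge above zero), with embeddings $f_a$, $f_p$, $f_n$:

$$
\mathcal{L} = \max\left(0,\ \lVert f_a - f_p \rVert_2^2 - \lVert f_a - f_n \rVert_2^2 + \alpha\right)
$$

$$
\frac{\partial \mathcal{L}}{\partial f_a} = 2(f_n - f_p), \qquad
\frac{\partial \mathcal{L}}{\partial f_p} = 2(f_p - f_a), \qquad
\frac{\partial \mathcal{L}}{\partial f_n} = 2(f_a - f_n)
$$

$$
\frac{\partial \mathcal{L}}{\partial f_a} + \frac{\partial \mathcal{L}}{\partial f_p} + \frac{\partial \mathcal{L}}{\partial f_n}
= 2\left(f_n - f_p + f_p - f_a + f_a - f_n\right) = 0
$$

The same cancellation holds for the non-squared distance used by default in `nn.TripletMarginLoss`, because the loss depends on the embeddings only through the differences $f_a - f_p$ and $f_a - f_n$.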

This is surprising indeed.
Do you see this gradient being non-zero after some time during training?
Also, if you ignore the gradients from one branch, is it non-zero?

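```python
# To ignore the gradient from one branch, just detach it.
# `net` (the shared embedding network), `crit` (the triplet loss) and the
# three input batches are assumed to be defined already.
base = net(inp_base)
pos = net(inp_pos)
neg = net(inp_neg)
# To ignore gradients from the neg branch: the detached output is treated
# as a constant, so no gradient flows back into `net` through it
neg = neg.detach()
loss = crit(base, pos, neg)
loss.backward()
```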
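To make the detach experiment concrete, here is a minimal sketch on a hypothetical toy setup: the two-layer `net`, the random inputs and the use of `nn.TripletMarginLoss` are assumptions, not the original model. It compares the last layer's weight gradient with and without detaching the negative branch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: a small shared embedding network and random
# inputs standing in for the anchor/positive/negative batches.
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
crit = nn.TripletMarginLoss(margin=1.0)
inp_base, inp_pos, inp_neg = (torch.randn(32, 8) for _ in range(3))


def last_layer_grad(detach_neg: bool) -> torch.Tensor:
    """Gradient of the loss w.r.t. the last Linear layer's weight."""
    net.zero_grad()
    base = net(inp_base)
    pos = net(inp_pos)
    neg = net(inp_neg)
    if detach_neg:
        # No gradient flows back into `net` through the negative branch.
        neg = neg.detach()
    loss = crit(base, pos, neg)
    loss.backward()
    return net[2].weight.grad.clone()


full = last_layer_grad(detach_neg=False)
no_neg = last_layer_grad(detach_neg=True)
print(full.norm().item(), no_neg.norm().item())
print(torch.allclose(full, no_neg))
```

In this toy setting both gradients are non-zero and differ, because detaching removes the negative branch's contribution to the shared weights.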

I attached two photos showing the gradients of the embedding layer, once without detaching any embedding and once after detaching the pos embeddings. I also fixed the seed and the model (both are at the same epoch and optimizer state). The model had been trained for a sufficiently long time when the gradients were captured. The gradients are not equal, but both are non-zero, which is expected.

What I do not understand is that adding up the gradients (the link is above) gives zero in theory. This would mean the network never updates its parameters. Luckily, this is not the case. Now I am simply confused about how the triplet loss is actually optimized given my implementation. (P.S. I might have overlooked a detail in the theory as well.)

Sorry if that was confusing. The code never gives a zero gradient in any case. But I am wondering how exactly PyTorch computes the gradients.

Isn’t it the case that when I call loss.backward(), the gradients are accumulated over the base, pos and neg instances? If so, since the weights of the embedding layer are literally the same tensors, I expect them to be updated by the sum of the gradients (times a constant). Surprisingly, in theory, this sum always turns out to be zero, as shown in the figure. But since I do not get zeros, there must be something I do not understand.
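The accumulation part of this intuition can be checked directly. The sketch below (again a hypothetical toy network with `nn.TripletMarginLoss`, not the original model) computes each branch's contribution to the shared weight gradient separately, by detaching the other two branches, and confirms that a single `backward()` through all three branches yields their sum.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: shared embedding network and random triplet inputs.
net = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
crit = nn.TripletMarginLoss(margin=1.0)
inputs = [torch.randn(32, 8) for _ in range(3)]  # anchor, positive, negative


def weight_grad(keep):
    """Last-layer weight gradient, letting gradients flow only through the
    branches whose index is in `keep`; the other outputs are detached."""
    net.zero_grad()
    outs = [net(x) if i in keep else net(x).detach()
            for i, x in enumerate(inputs)]
    crit(*outs).backward()
    return net[2].weight.grad.clone()


full = weight_grad({0, 1, 2})
per_branch = [weight_grad({i}) for i in range(3)]

# backward() accumulates one contribution per use of the shared weight
print(torch.allclose(full, sum(per_branch), atol=1e-6))
print([g.norm().item() for g in per_branch])
```

So the weights do receive a sum, but it is a sum of per-branch *weight* gradients, not of the gradients w.r.t. the embeddings.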

Right, but these formulas give the gradients w.r.t. the output of the net, not the weights. And you never actually sum the gradient contributions of the outputs (because these are different Tensors); only the weights are actually shared, and so only their gradients are accumulated.
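A small sketch can show both facts side by side, under the assumption of a toy network whose last `nn.Linear` plays the role of the embedding layer (the names `body` and `embed` are hypothetical). The gradients w.r.t. the three output tensors cancel when summed, yet the shared weight's gradient, accumulated as `g_i^T @ h_i` per branch, does not. As a side effect of the cancellation, the last layer's bias gradient, which is just the summed output gradient, comes out as (numerically) zero.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: `embed` is the last (embedding) Linear layer.
body = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
embed = nn.Linear(16, 4)
crit = nn.TripletMarginLoss(margin=1.0)
x_a, x_p, x_n = (torch.randn(32, 8) for _ in range(3))

h = [body(x) for x in (x_a, x_p, x_n)]  # inputs to the embedding layer
outs = [embed(hi) for hi in h]          # three separate output tensors
for o in outs:
    o.retain_grad()                     # keep .grad on non-leaf outputs

crit(*outs).backward()

# Gradients w.r.t. the three outputs cancel when summed...
g_sum = sum(o.grad for o in outs)
print(g_sum.abs().max().item())
# ...but the shared weight accumulates g_i^T @ h_i, which does not cancel
print(embed.weight.grad.norm().item())
# The bias gradient is the summed output gradient, so it is ~0
print(embed.bias.grad.abs().max().item())

# Rebuild the accumulated weight gradient by hand
manual = sum(o.grad.T @ hi for o, hi in zip(outs, h))
print(torch.allclose(manual, embed.weight.grad, atol=1e-6))
```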

Yes it does explain a lot, thank you very much @albanD!

I was implicitly assuming that df/dW_last is the same for all three instances because the weights are identical. That is not the case: the derivatives are input-dependent. So instead of a sum of the three gradients, each weighted by its own input-dependent term, I was computing Const * (sum of the grads), which is 0.
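Written out for a single triplet, assuming the last layer is linear, $f_i = W h_i + b$, where $h_i$ is that layer's input for branch $i \in \{a, p, n\}$ and $g_i = \partial \mathcal{L} / \partial f_i$ (notation introduced here for illustration), the chain rule gives:

$$
\frac{\partial \mathcal{L}}{\partial W} = \sum_{i \in \{a,p,n\}} g_i\, h_i^{\top},
\qquad
\frac{\partial \mathcal{L}}{\partial b} = \sum_{i \in \{a,p,n\}} g_i = 0 .
$$

Only if $h_a = h_p = h_n = h$ would the weight gradient collapse to $\left(\sum_i g_i\right) h^{\top} = 0$, which is exactly the "Const * sum of grads" assumption. Since the three inputs differ, the weight gradient is generally non-zero; only the last layer's bias gradient vanishes.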