I am implementing a triplet network in PyTorch using three forward passes and a single backward pass, similar to the approach described here. The model is trained by optimizing a triplet loss, and after running the code everything works as expected.
However, I am still confused about how backpropagation works in this implementation. Specifically, since PyTorch accumulates gradients, the gradients of the triplet loss w.r.t. the last linear layer (the embedding layer) always add up to zero (shown here). Of course, this cannot be the whole story, as the network does eventually learn meaningful embeddings.
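In case it helps, here is a minimal self-contained sketch of what I mean (the layer sizes and data are hypothetical, not my actual model): three forward passes through a shared-weight encoder, one backward pass, and the gradients of the loss w.r.t. the three embedding outputs summing to zero.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical encoder; the real model's sizes differ.
encoder = nn.Sequential(
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 8),  # last linear (embedding) layer
)
loss_fn = nn.TripletMarginLoss(margin=1.0)

anchor, positive, negative = (torch.randn(4, 10) for _ in range(3))

# Three forward passes through the same shared-weight encoder.
emb_a, emb_p, emb_n = encoder(anchor), encoder(positive), encoder(negative)
loss = loss_fn(emb_a, emb_p, emb_n)

# The gradients w.r.t. the three embedding outputs sum to
# (numerically) zero, since the loss depends only on the
# differences between the embeddings.
g_a, g_p, g_n = torch.autograd.grad(
    loss, (emb_a, emb_p, emb_n), retain_graph=True
)
print((g_a + g_p + g_n).abs().max())  # ~0

# A single backward pass then accumulates the gradients from all
# three branches into the shared parameters.
loss.backward()
```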
Can anyone explain this apparent contradiction?