Separating a batch of images before or after embedding causes massively different results. Why?

Hello.
I am trying to understand a weird phenomenon and I really hope you can help me, as it seems to point to something quite fundamental that I am missing.

I have a batch of images. The first k are passed through an embedding network; the other k*n are passed through the same network and then through an RNN, which outputs k hidden states (I believe the details are not important). I then feed the k hidden states and the k embeddings from the first group into a relation network which compares them.
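For context, here is a rough runnable sketch of what I mean (all shapes and module choices below are placeholders, not my actual architecture):

import torch
import torch.nn as nn

k, n, d = 4, 3, 32

# Placeholder stand-ins for the three networks described above.
image_embedding = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, d))
rnn = nn.GRU(d, d, batch_first=True)
relation = nn.Sequential(nn.Linear(2 * d, d), nn.ReLU(), nn.Linear(d, 1))

x = torch.randn(k + k * n, 3, 8, 8)   # the whole batch of images

emb_A = image_embedding(x[:k])        # (k, d) embeddings of the first group
emb_B = image_embedding(x[k:])        # (k*n, d) embeddings of the second group
_, h = rnn(emb_B.view(k, n, d))       # final hidden state per sequence: (1, k, d)
scores = relation(torch.cat([emb_A, h.squeeze(0)], dim=1))  # (k, 1) comparisons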

Everything else being equal, I noticed that whether I embed before or after splitting the batch at the beginning of the training step matters: in one case the gradient rapidly goes to zero and the network doesn’t learn anything, and in the other it learns rapidly.
x is the batch of images.
This is the one that doesn’t work.

A = x[:k]                         # first k images
B = x[k:]                         # remaining k*n images
emb_A = model.image_embedding(A)  # embed each group separately
emb_B = model.image_embedding(B)
...

I have also tried to use two different networks (with the same architecture) for emb_A and emb_B, with the same unsatisfactory results.

This is the one that works:

emb_all = model.image_embedding(x)  # embed the whole batch at once
emb_A = emb_all[:k]
emb_B = emb_all[k:]
...

I don’t understand why it should matter whether I separate the images before or after doing the embedding. I can see how reusing the same network on the two groups might give the gradient trouble converging, but as I said, the same thing happens when two networks are used, which makes me believe it’s due to the slicing operation.
Does anyone have any idea?
Thanks

The operations should be equal up to the error created by the limited floating point precision, as seen in this example (using an nn.Embedding as a stand-in for your image embedding network):

# setup
import torch
import torch.nn as nn

x = torch.randint(0, 100, (100, 10))
image_embedding = nn.Embedding(100, 300)
k = 50

# check forward
A = x[:k]
B = x[k:]
emb_A1 = image_embedding(A)
emb_B1 = image_embedding(B)

emb_all = image_embedding(x)
emb_A2 = emb_all[:k]
emb_B2 = emb_all[k:]

print((emb_A1 - emb_A2).abs().max())
print((emb_B1 - emb_B2).abs().max())

# check backward
emb_A1.mean().backward()
grad_A1 = image_embedding.weight.grad.clone()

image_embedding.zero_grad()
emb_A2.mean().backward(retain_graph=True)  # emb_B2 still needs this graph
grad_A2 = image_embedding.weight.grad.clone()
print((grad_A1 - grad_A2).abs().max())

image_embedding.zero_grad()  # clear grad_A2 before accumulating B's gradient
emb_B1.mean().backward()
grad_B1 = image_embedding.weight.grad.clone()

image_embedding.zero_grad()
emb_B2.mean().backward()
grad_B2 = image_embedding.weight.grad.clone()
print((grad_B1 - grad_B2).abs().max())
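
If the two orderings are truly equivalent, all four printed differences should be zero, or at worst on the order of float32 rounding error (roughly 1e-7).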

Thank you. So the error must be somewhere else.

Just a wild guess: do you use batchnorm in your network?
If so, and if A belongs to a different distribution of images than B, then the batchnorm statistics of the two groups may be wildly different, and I guess the network would have a tough time learning stable statistics when you pass the groups separately. When you pass them together, both groups get normalized with the same batch statistics.
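
A minimal sketch of that effect, using toy 1-D features in place of images (the shapes and distributions are made up for illustration):

import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(10)  # training mode: normalizes with batch statistics
k = 50

# Toy features standing in for the two image groups, deliberately
# drawn from different distributions.
A = torch.randn(k, 10) + 5.0
B = torch.randn(k, 10) - 5.0
x = torch.cat([A, B], dim=0)

# Separate passes: each group is normalized with its own mean/std.
out_A_sep = bn(A)
out_B_sep = bn(B)

# Joint pass: both groups share the batch statistics, so the shift
# between A and B survives the normalization.
out_all = bn(x)

print((out_A_sep - out_all[:k]).abs().max())  # large, not rounding error
print((out_B_sep - out_all[k:]).abs().max())

If image_embedding contains a batchnorm layer, the same thing would happen inside your network: the separate passes normalize each group against its own statistics, while the joint pass normalizes both against shared statistics.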
