How are batch norm statistics affected during multiple forward passes?

Consider a simple Siamese network with inputs from two datasets, x1 and x2. Which is the correct way to forward the inputs when the network contains batch norm layers?

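net = Siamese()
# Concatenate both inputs along the batch dimension and run a single forward pass
X = torch.cat([x1, x2], dim=0)
O = net(X)
# Split the output back: the first x1.shape[0] rows belong to x1, the rest to x2
o1 = O[:x1.shape[0], ...]
o2 = O[x1.shape[0]:, ...]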

or alternatively,

net = Siamese()
# Two separate forward passes through the same network (shared weights)
o1 = net(x1)
o2 = net(x2)

Note that x1 and x2 come from different datasets.
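To make the question concrete, here is a minimal sketch of how the two options could be compared. It assumes a single `nn.BatchNorm1d` layer as a hypothetical stand-in for `Siamese()`, and random tensors with deliberately different means and scales as hypothetical stand-ins for the two datasets. In training mode a batch norm layer normalizes with the statistics of the batch it is given and updates its running statistics once per forward call, so the two options should not be expected to behave identically.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two datasets, with different statistics.
x1 = torch.randn(64, 4)              # roughly mean 0, std 1
x2 = torch.randn(64, 4) * 3.0 + 5.0  # roughly mean 5, std 3


def make_bn():
    # Hypothetical stand-in for Siamese(): one BatchNorm1d layer in training mode.
    return nn.BatchNorm1d(4, affine=False).train()


# Option A: one forward pass over the concatenated batch.
bn_a = make_bn()
out_a = bn_a(torch.cat([x1, x2], dim=0))
o1_a, o2_a = out_a[:x1.shape[0]], out_a[x1.shape[0]:]

# Option B: two separate forward passes through the same layer.
bn_b = make_bn()
o1_b = bn_b(x1)
o2_b = bn_b(x2)

# Option A updates the running statistics once, option B updates them twice.
print("num_batches_tracked (A, B):",
      bn_a.num_batches_tracked.item(), bn_b.num_batches_tracked.item())
print("running_mean A:", bn_a.running_mean)
print("running_mean B:", bn_b.running_mean)

# In option B, x1 is normalized with its own batch statistics;
# in option A it is normalized with the statistics of the joint batch.
print("mean of o1 (A, B):", o1_a.mean().item(), o1_b.mean().item())
```

With a real network the same check can be done by inspecting `running_mean`, `running_var` and `num_batches_tracked` on each batch norm layer after the forward passes.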
