I could not find a solution anywhere.
Please help me with this problem.
I trained my model with a batch size of 32 on 3 GPUs.
The model contains several BatchNorm1d layers (plus some dropout layers).
For testing I call model.eval(), and the BatchNorm1d layers have track_running_stats = False.
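For reference, here is a minimal sketch (not from the original post) of what a BatchNorm1d layer constructed with track_running_stats=False does in eval mode. Such a layer stores no running buffers, so it keeps normalizing each input with the mean and variance of the current batch. The layer size and the random input are arbitrary assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A BatchNorm1d built with track_running_stats=False keeps no running buffers.
bn = nn.BatchNorm1d(3, track_running_stats=False)
bn.eval()  # eval mode cannot switch it to running statistics: there are none

x = torch.randn(4, 3)  # hypothetical input of shape (batch, channels)

with torch.no_grad():
    out = bn(x)
    # Normalize manually with the statistics of *this* batch (biased variance).
    mean = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    manual = (x - mean) / torch.sqrt(var + bn.eps) * bn.weight + bn.bias

print(bn.running_mean is None)                 # True
print(torch.allclose(out, manual, atol=1e-5))  # True
```

Because the statistics come from the batch, the same item can be normalized differently depending on which other items share its batch.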
When I load a test sample x and run it through the model with model(x), the result is completely different from the outputs I saw during training.
For example, the outputs range from 0 to 0.99 during training with a batch size of 32, but only from 0 to 0.05 during testing with a batch size of 1.
To investigate further, I loaded two or more test samples from the dataloader, ran them through the model with different batch sizes, and got different outputs for the same samples.
For example, I loaded 4 items (x) from the loader:
```python
import torch

# x: 4 items loaded from the test loader; model is already in eval mode
with torch.no_grad():
    x1 = x[:2]                                  # batch size 2
    x2 = x[2:]                                  # batch size 2
    x11 = torch.cat([x[0:1], x[0:1]], dim=0)    # batch size 2 of the same item

    y0 = model(x[0:1])   # batch size 1
    y1 = model(x1)       # batch size 2
    y2 = model(x2)       # batch size 2
    y11 = model(x11)     # batch size 2 of the same item
    y = model(x)         # batch size 4

print(torch.allclose(y1, y[:2]))       # False: y[:2] differs from y1
print(torch.allclose(y2, y[2:]))       # False: y[2:] also differs from y2
print(torch.allclose(y11[0], y11[1]))  # True: both copies give the same output
print(torch.allclose(y1, y11))         # False: y11 differs from y1
print(torch.allclose(y1, y0), torch.allclose(y11, y0))  # False, False: y0 differs from both
```
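The same kind of check can be reproduced in a self-contained sketch. It assumes a hypothetical toy model (Conv1d, then BatchNorm1d with track_running_stats=False, then Sigmoid) on (N, C, L) inputs in place of the real model, which is not shown in the post. Even in eval mode, each output depends on which other items share the batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the real model; input shape (N, C=2, L=8).
model = nn.Sequential(
    nn.Conv1d(2, 4, kernel_size=3, padding=1),
    nn.BatchNorm1d(4, track_running_stats=False),
    nn.Sigmoid(),
)
model.eval()

x = torch.randn(4, 2, 8)  # 4 hypothetical test items

with torch.no_grad():
    y = model(x)         # all 4 items together
    y1 = model(x[:2])    # first two items on their own
    y0 = model(x[0:1])   # first item alone

# The batch statistics differ between these calls, so the outputs differ too.
print(torch.allclose(y1, y[:2]))   # False
print(torch.allclose(y0, y[0:1]))  # False
```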
The problems are:
- The model produces different values depending on the batch size during testing.
- y[:2] is different from y1, and y[2:] is different from y2. y0 also differs from both y11 and y1.
In particular, with a batch size of 1 (the y0 case), the output histogram ranges from 0 to 0.05, which is not intended,
whereas a batch of 2 or more different items gives a range of 0 to 0.99, as intended and as seen during training.
When I enlarge the batch by repeating the same item, both copies produce the same output (the y11 check returns True). That seems correct, but the histogram still ranges from 0 to 0.05.
I suspect the BatchNorm1d layers are causing this.
Can anyone help, or give me a hint on how to solve it?
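For comparison, a sketch of the same hypothetical toy model with BatchNorm1d left at its default track_running_stats=True. Forward passes in train mode fill running_mean and running_var, and in eval mode the layer normalizes with those stored values, so the outputs no longer depend on batch composition. This assumes the model can be retrained (or its statistics re-estimated) with running statistics enabled; the random training batches below are placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same hypothetical toy model, but BatchNorm1d keeps running statistics.
model = nn.Sequential(
    nn.Conv1d(2, 4, kernel_size=3, padding=1),
    nn.BatchNorm1d(4),  # track_running_stats=True by default
    nn.Sigmoid(),
)

# Placeholder for training: train-mode passes update running_mean/running_var.
model.train()
with torch.no_grad():
    for _ in range(10):
        model(torch.randn(32, 2, 8))

model.eval()  # BatchNorm1d now uses the stored running statistics
x = torch.randn(4, 2, 8)
with torch.no_grad():
    y = model(x)
    y1 = model(x[:2])
    y0 = model(x[0:1])

print(torch.allclose(y1, y[:2], atol=1e-6))   # True
print(torch.allclose(y0, y[0:1], atol=1e-6))  # True
```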