Hello,
I could not find a solution to this anywhere, so I would appreciate any help.
I trained my model with a batch size of 32 on 3 GPUs.
The model contains several BatchNorm1d layers (plus some dropout layers).
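Because the model mixes BatchNorm1d and dropout, it may help to recall how BatchNorm1d behaves in each mode. The sketch below uses a standalone layer with made-up random data, not the actual model, and assumes the default track_running_stats=True. It trains the layer on random batches, switches it to eval mode, and checks two things. First, the output then matches normalization with the stored running_mean and running_var. Second, a sample's output no longer depends on the other samples in the batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(num_features=3)  # defaults: affine=True, track_running_stats=True

# Training mode: each batch is normalized with its own mean/variance,
# and running_mean / running_var are updated as a side effect.
bn.train()
for _ in range(100):
    bn(torch.randn(32, 3) * 2.0 + 5.0)

# Eval mode: the stored running statistics are used instead of batch statistics.
bn.eval()
x = torch.randn(4, 3) * 2.0 + 5.0
with torch.no_grad():
    y_eval = bn(x)
    # The same normalization written out by hand
    y_manual = (x - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps) * bn.weight + bn.bias
    print(torch.allclose(y_eval, y_manual, atol=1e-6))  # True
    # A sub-batch gives the same rows as the full batch
    print(torch.allclose(bn(x[:2]), y_eval[:2], atol=1e-6))  # True
```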
During testing, I checked the following:
- model.eval() is called
- track_running_stats = False
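To see what those two settings amount to for each layer, a sketch like the following lists every BatchNorm1d with its training flag, its track_running_stats value, and whether its running buffers exist. The model is a hypothetical stand-in, and describe_batchnorm is a made-up helper name. According to the PyTorch BatchNorm1d documentation, a layer without running estimates (track_running_stats=False) uses batch statistics during evaluation as well. Changing the attribute after the layer has been constructed may behave differently across PyTorch versions, so it seems worth checking on the installed version.

```python
import torch.nn as nn

# Hypothetical stand-in for the real model
model = nn.Sequential(
    nn.Linear(8, 16),
    nn.BatchNorm1d(16),                            # keeps running_mean / running_var
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(16, 16),
    nn.BatchNorm1d(16, track_running_stats=False),  # built without running buffers
    nn.ReLU(),
    nn.Linear(16, 1),
    nn.Sigmoid(),
)
model.eval()

def describe_batchnorm(model: nn.Module) -> list:
    """Return (name, training, track_running_stats, has_running_buffers) for each BatchNorm1d."""
    rows = []
    for name, m in model.named_modules():
        if isinstance(m, nn.BatchNorm1d):
            has_buffers = m.running_mean is not None and m.running_var is not None
            rows.append((name, m.training, m.track_running_stats, has_buffers))
    return rows

for row in describe_batchnorm(model):
    print(row)
# ('1', False, True, True)
# ('5', False, False, False)
```

A layer that reports has_running_buffers as False has nothing to fall back on in eval mode, so its output depends on the batch it is given.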
When I load a single test sample x and run it through the model with model(x), the output is completely different from the outputs during training.
For example, during training with a batch size of 32 the model outputs range roughly from 0 to 0.99, while during testing with a batch size of 1 they only range from 0 to 0.05.
To examine this further, I loaded 2 or more test samples from the dataloader and passed them through the model with different batch sizes. The outputs differed even when the data were the same.
For example, I loaded 4 items (x) from the loader:
```python
import torch

x1 = x[:2]  # batch size 2
x2 = x[2:]  # batch size 2
x11 = torch.cat([x[0:1], x[0:1]], dim=0)  # batch size 2, the same sample twice (x[0:1] keeps the batch dim)
y0 = model(x[0:1])  # batch size 1
y1 = model(x1)  # batch size 2
y2 = model(x2)  # batch size 2
y11 = model(x11)  # batch size 2, the same sample twice
y = model(x)  # batch size 4
print(torch.allclose(y1, y[:2]))  # False: y[:2] differs from y1
print(torch.allclose(y2, y[2:]))  # False: y[2:] differs from y2
print(torch.allclose(y11[0], y11[1]))  # True
print(torch.allclose(y1[0], y11[0]))  # False: y11[0] equals y11[1] but differs from y1[0]
print(torch.allclose(y1[0], y0[0]), torch.allclose(y11[0], y0[0]))  # False, False: y0 differs from both y1[0] and y11[0]
```
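The comparisons above can be turned into a reusable check. This is a minimal sketch on a hypothetical toy model, and per_sample_matches_batched is an illustrative helper, not a PyTorch function. It runs each sample on its own and compares the result with the full-batch output. With default BatchNorm1d layers in eval mode, where normalization uses the running statistics, the check should pass. The False results above suggest the real model would not pass it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model with BatchNorm1d and Dropout
model = nn.Sequential(
    nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Dropout(0.5),
    nn.Linear(16, 1), nn.Sigmoid(),
)

# A few training-mode passes so the running statistics are not just the defaults
# (running stats are updated in train mode even under no_grad)
model.train()
with torch.no_grad():
    for _ in range(10):
        model(torch.randn(32, 8))
model.eval()

def per_sample_matches_batched(model: nn.Module, x: torch.Tensor, atol: float = 1e-6) -> bool:
    """True if running each sample alone gives the same output as running the whole batch."""
    with torch.no_grad():
        batched = model(x)
        singles = torch.cat([model(x[i:i + 1]) for i in range(x.shape[0])], dim=0)
    return torch.allclose(batched, singles, atol=atol)

x = torch.randn(4, 8)
print(per_sample_matches_batched(model, x))  # True
```

The small atol allows for floating-point differences between batched and single-row matrix multiplications, which are unrelated to the batch-norm issue.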
The problems are:
- The model produces different outputs depending on the batch size during testing.
- y[:2] differs from y1, y[2:] differs from y2, and y0 differs from both y11[0] and y1[0].
- In particular, with a batch size of 1 (the y0 case), the output histogram ranges from 0 to 0.05, which is not intended. A batch size of 2 or more with different items gives outputs from 0 to 0.99, which is intended and matches training.
- If I enlarge the batch manually by repeating the same sample, the model gives the same value for every copy: torch.allclose(y11[0], y11[1]) returns True. This seems correct, but the histogram still ranges from 0 to 0.05.
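To narrow down which layer makes the output depend on the batch, one option is to record every layer's output with forward hooks. The outputs for x[:2] run alone can then be compared against the first two rows of the full batch. The sketch below does this on a hypothetical stand-in model in which the second BatchNorm1d is deliberately built with track_running_stats=False. capture_outputs and first_divergent_layer are illustrative helper names. On the real model, the reported name would point at the first layer whose output changes with the batch composition.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in model; the BatchNorm1d at index 4 has no running buffers,
# so it normalizes with batch statistics even in eval mode.
model = nn.Sequential(
    nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(),
    nn.Linear(16, 16), nn.BatchNorm1d(16, track_running_stats=False), nn.ReLU(),
    nn.Linear(16, 1), nn.Sigmoid(),
)
model.eval()

def capture_outputs(model: nn.Module, x: torch.Tensor) -> dict:
    """Run x through the model and record each leaf module's output, in forward order."""
    outputs, handles = {}, []
    for name, m in model.named_modules():
        if len(list(m.children())) == 0:  # leaf modules only
            def hook(_module, _inputs, out, name=name):
                outputs[name] = out.detach()
            handles.append(m.register_forward_hook(hook))
    try:
        with torch.no_grad():
            model(x)
    finally:
        for h in handles:
            h.remove()  # leave the model without hooks afterwards
    return outputs

def first_divergent_layer(model: nn.Module, x: torch.Tensor, atol: float = 1e-6):
    """Return the first layer whose output for x[:2] alone differs from rows 0-1 of the full batch."""
    small = capture_outputs(model, x[:2])
    full = capture_outputs(model, x)
    for name, out in small.items():
        if not torch.allclose(out, full[name][:2], atol=atol):
            return name
    return None

x = torch.randn(4, 8)
print(first_divergent_layer(model, x))  # '4'
```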
I suspect the BatchNorm1d layers are causing these problems.
Could someone help me or give me a hint on how to solve them?