Hi, I’m having a different (better) accuracy when .eval() is not called and I can’t figure out why is that.
I use MSE as the loss function, getting an error of 0.002 (avg) in the training set and 0.08 (avg) in the validation/test set, all of this calling .eval() . When I comment the .eval() line, the MSE of the test set goes down to 0.002.
Here’s my model:
class Net3(nn.Module): # U-net without skip connections
def __init__(self, kernel_sz):
super(Net3, self).__init__()
i_size = 16
self.mp = nn.MaxPool3d(kernel_size=2, padding=0, stride=2, dilation=1, return_indices=True) # Reduces size by half
self.batchNorm_i = nn.BatchNorm3d(i_size)
self.batchNorm_2i = nn.BatchNorm3d(i_size * 2)
self.batchNorm_4i = nn.BatchNorm3d(i_size * 4)
self.batchNorm_8i = nn.BatchNorm3d(i_size * 8)
self.batchNorm_16i = nn.BatchNorm3d(i_size * 16)
self.dropout50 = nn.Dropout3d(p = 0.5) # p is the probability of being zeroed
self.dropout20 = nn.Dropout3d(p = 0.2)
self.dropout10 = nn.Dropout3d(p = 0.1)
self.conv1_i = nn.Conv3d(1, i_size, kernel_size=kernel_sz, padding=1)
self.convi_i = nn.Conv3d(i_size, i_size, kernel_size=kernel_sz, padding=1)
self.convi_2i = nn.Conv3d(i_size, i_size * 2, kernel_size=kernel_sz, padding=1)
self.conv2i_2i = nn.Conv3d(i_size * 2, i_size * 2, kernel_size=kernel_sz, padding=1)
self.conv2i_4i = nn.Conv3d(i_size * 2, i_size * 4, kernel_size=kernel_sz, padding=1)
self.conv4i_4i = nn.Conv3d(i_size * 4, i_size * 4, kernel_size=kernel_sz, padding=1)
self.conv4i_8i = nn.Conv3d(i_size * 4, i_size * 8, kernel_size=kernel_sz, padding=1)
self.conv8i_8i = nn.Conv3d(i_size * 8, i_size * 8, kernel_size=kernel_sz, padding=1)
self.conv8i_16i = nn.Conv3d(i_size * 8, i_size * 16, kernel_size=kernel_sz, padding=1)
self.conv16i_16i = nn.Conv3d(i_size * 16, i_size * 16, kernel_size=kernel_sz, padding=1)
self.conv16i_8i = nn.Conv3d(i_size * 16, i_size * 8, kernel_size=kernel_sz, padding=1)
self.conv8i_4i = nn.Conv3d(i_size * 8, i_size * 4, kernel_size=kernel_sz, padding=1)
self.conv4i_2i = nn.Conv3d(i_size * 4, i_size * 2, kernel_size=kernel_sz, padding=1)
self.conv2i_i = nn.Conv3d(i_size * 2, i_size, kernel_size=kernel_sz, padding=1)
self.convi_1 = nn.Conv3d(i_size, 1, kernel_size=1) # 1x1 conv
self.upconv16i_16i = nn.ConvTranspose3d(i_size * 16, i_size * 16, kernel_size=2, stride=2)
self.upconv8i_8i = nn.ConvTranspose3d(i_size * 8, i_size * 8, kernel_size=2, stride=2)
self.upconv4i_4i = nn.ConvTranspose3d(i_size * 4, i_size * 4, kernel_size=2, stride=2)
self.upconv2i_2i = nn.ConvTranspose3d(i_size * 2, i_size * 2, kernel_size=2, stride=2)
def forward(self, x):
c1 = self.batchNorm_i(self.conv1_i(self.dropout50(x)))
r1 = F.relu(c1)
c2 = self.batchNorm_i(self.convi_i(r1))
r2 = F.relu(c2)
mp1, idxi = self.mp(r2) # 1st max-pooling
c3 = self.batchNorm_2i(self.convi_2i(mp1))
r3 = F.relu(c3)
c4 = self.batchNorm_2i(self.conv2i_2i(r3))
r4 = F.relu(c4)
mp2, idx2i = self.mp(r4) # 2nd max-pooling
c5 = self.batchNorm_4i(self.conv2i_4i(mp2))
r5 = F.relu(c5)
c6 = self.batchNorm_4i(self.conv4i_4i(r5))
r6 = F.relu(c6)
mp3, idx4i = self.mp(r6) # 3rd max-pooling
c7 = self.batchNorm_8i(self.conv4i_8i(mp3))
r7 = F.relu(c7)
c8 = self.batchNorm_8i(self.conv8i_8i(r7))
r8 = F.relu(c8)
mp4, idx8i = self.mp(r8) # 4th max-pooling
c9 = self.batchNorm_16i(self.conv8i_16i(mp4)) # Lowest resolution
r9 = F.relu(c9)
c10 = self.batchNorm_16i(self.conv16i_16i(r9))
r10 = F.relu(c10)
uc1 = self.upconv16i_16i(r10) # 1st upconvolution
c11 = self.batchNorm_8i(self.conv16i_8i(uc1))
r11 = F.relu(c11)
c12 = self.batchNorm_8i(self.conv8i_8i(r11))
r12 = F.relu(c12)
uc2 = self.upconv8i_8i(r12) # 2nd upconvolution
c13 = self.batchNorm_4i(self.conv8i_4i(uc2))
r13 = F.relu(c13)
c14 = self.batchNorm_4i(self.conv4i_4i(r13))
r14 = F.relu(c14)
uc3 = self.upconv4i_4i(r14) # 3rd upconvolution
c15 = self.batchNorm_2i(self.conv4i_2i(uc3))
r15 = F.relu(c15)
c16 = self.batchNorm_2i(self.conv2i_2i(r15))
r16 = F.relu(c16)
uc4 = self.upconv2i_2i(r16) # 4th upconvolution
c17 = self.batchNorm_i(self.conv2i_i(uc4))
r17 = F.relu(c17)
c18 = self.batchNorm_i(self.convi_i(r17))
r18 = F.relu(c18)
c19 = self.convi_1(r18)
return c19
Thanks!