RuntimeError: running_mean should contain 16 elements not 32

I am using a RE NET from GitHub - iMED-Lab/RE-Net: 3D cerebrovascular volume segmentation in Pytorch., and my input data is 448x448x128. When I run the Python program in Google Colab it gives the error RuntimeError: running_mean should contain 16 elements not 32.

The algorithm is here:


nonlinearity = partial(F.relu, inplace=True)

def downsample():
return nn.MaxPool3d(kernel_size=2, stride=2)

def deconv(in_channels, out_channels):
return nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)

def initialize_weights(*models):
for model in models:
for m in model.modules():
if isinstance(m, nn.Conv3d) or isinstance(m, nn.Linear):
nn.init.kaiming_normal(m.weight)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm3d):
m.weight.data.fill_(1)
m.bias.data.zero_()

class ResEncoder(nn.Module):
def init(self, in_channels, out_channels):
super(ResEncoder, self).init()
self.conv1 = nn.Conv3d(in_channels, out_channels//2, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm3d(out_channels//2)
self.conv2 = nn.Conv3d(out_channels//2, out_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm3d(out_channels//2)
self.relu = nn.ReLU(inplace=False)
self.conv1x1 = nn.Conv3d(in_channels, out_channels, kernel_size=1)

def forward(self, x):
    residual = self.conv1x1(x)
    out = self.relu(self.bn1(self.conv1(x)))
    out = self.relu(self.bn2(self.conv2(out)))
    out += residual
    out = self.relu(out)
    return out

class Decoder(nn.Module):
def init(self, in_channels, out_channels):
super(Decoder, self).init()
self.conv = nn.Sequential(
nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm3d(out_channels),
nn.ReLU(inplace=True),
nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm3d(out_channels),
nn.ReLU(inplace=True)
)

def forward(self, x):
    out = self.conv(x)
    return out

class RE_Net(nn.Module):
# def init(self, classes, channels):
def init(self):

    super(RE_Net, self).__init__()
    self.encoder1 = ResEncoder(1, 32)
    self.encoder2 = ResEncoder(32, 64)
    self.encoder3 = ResEncoder(64, 128)
    self.bridge = ResEncoder(128, 256)

    self.conv1_1 = nn.Conv3d(256, 1, kernel_size=1)
    self.conv2_2 = nn.Conv3d(128, 1, kernel_size=1)
    self.conv3_3 = nn.Conv3d(64, 1, kernel_size=1)


    self.convTrans1 = nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2)
    self.convTrans2 = nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2)
    self.convTrans3 = nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2)



    self.decoder3 = Decoder(256, 128)
    self.decoder2 = Decoder(128, 64)
    self.decoder1 = Decoder(64, 32)
    self.down = downsample()
    self.up3 = deconv(256, 128)
    self.up2 = deconv(128, 64)
    self.up1 = deconv(64, 32)
    self.final = nn.Conv3d(32, 1, kernel_size=1, padding=0)
    initialize_weights(self)

def forward(self, x):
    enc1 = self.encoder1(x)
    down1 = self.down(enc1)

    enc2 = self.encoder2(down1)
    down2 = self.down(enc2)

    con3_3 = self.conv3_3(enc2)
    convTrans3 = self.convTrans3(con3_3)
    x3 = -1 * (torch.sigmoid(convTrans3)) + 1
    x3 = x3.expand(-1, 32, -1, -1, -1).mul(enc1)
    x3 = x3 + enc1

    enc3 = self.encoder3(down2)
    down3 = self.down(enc3)

    con2_2 = self.conv2_2(enc3)
    convTrans2 = self.convTrans2(con2_2)
    x2 = -1 * (torch.sigmoid(convTrans2)) + 1
    x2 = x2.expand(-1, 64, -1, -1, -1).mul(enc2)
    x2 = x2 + enc2



    bridge = self.bridge(down3)

    conv1_1 = self.conv1_1(bridge)
    convTrans1 = self.convTrans1(conv1_1)


    x = -1 * (torch.sigmoid(convTrans1)) + 1
    x = x.expand(-1, 128, -1, -1, -1).mul(enc3)
    x = x + enc3

    up3 = self.up3(bridge)
    up3 = torch.cat((up3, x), dim=1)
    dec3 = self.decoder3(up3)

    up2 = self.up2(dec3)
    up2 = torch.cat((up2, x2), dim=1)
    dec2 = self.decoder2(up2)

    up1 = self.up1(dec2)
    up1 = torch.cat((up1, x3), dim=1)
    dec1 = self.decoder1(up1)

    final = self.final(dec1)
    final = F.sigmoid(final)
    return final

I have researched and found out the issue has something to do with batch normalization, and I have tried to adjust the numbers, but I do not know which line is causing the problem. If anyone could help, that would be great. Thanks!

self.bn2 seems to be wrong:

self.conv2 = nn.Conv3d(out_channels//2, out_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm3d(out_channels//2)

as it’s expecting out_channels//2 features while self.conv2 outputs an activation with out_channels channels. Make sure the number of output channels from conv2 is equal to the expected features in bn2 and it should work.