I am using RE-Net (the iMED-Lab/RE-Net repository on GitHub, for 3D cerebrovascular volume segmentation in PyTorch), and my input data is 448x448x128. When I run the Python program in Google Colab it gives the error RuntimeError: running_mean should contain 16 elements not 32.
The algorithm is here:
# Module-level activation helper: in-place ReLU via functools.partial.
# NOTE(review): not referenced anywhere in this snippet — presumably used
# elsewhere in the original repository; confirm before removing.
nonlinearity = partial(F.relu, inplace=True)
def downsample():
    """Build a 3D max-pooling layer that halves every spatial dimension."""
    pool = nn.MaxPool3d(kernel_size=2, stride=2)
    return pool
def deconv(in_channels, out_channels):
    """Build a transposed 3D convolution that doubles every spatial dimension."""
    layer = nn.ConvTranspose3d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=2,
        stride=2,
    )
    return layer
def initialize_weights(*models):
    """Initialize the weights of one or more modules in place.

    Conv3d / Linear layers get Kaiming-normal weights and zeroed biases;
    BatchNorm3d layers get weight=1 and bias=0 (identity affine transform).
    """
    for model in models:
        for m in model.modules():
            if isinstance(m, (nn.Conv3d, nn.Linear)):
                # FIX: nn.init.kaiming_normal is deprecated — use the
                # in-place variant kaiming_normal_.
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
class ResEncoder(nn.Module):
    """Residual encoder block.

    Two 3x3x3 conv + BN + ReLU stages (in_channels -> out_channels//2 ->
    out_channels) plus a 1x1x1 projection shortcut so the input can be
    added to the output even when the channel count changes.
    """

    def __init__(self, in_channels, out_channels):
        # BUG FIX: the original defined `def init` and called
        # `super().init()`; Python never runs a method named `init` as a
        # constructor — it must be `__init__`.
        super(ResEncoder, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels // 2, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_channels // 2)
        self.conv2 = nn.Conv3d(out_channels // 2, out_channels, kernel_size=3, padding=1)
        # BUG FIX (the reported error): conv2 emits `out_channels` feature
        # maps, so its batch norm must track `out_channels` statistics.
        # The original passed `out_channels // 2`, producing
        # "RuntimeError: running_mean should contain 16 elements not 32".
        self.bn2 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=False)
        self.conv1x1 = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + conv1x1(x))."""
        residual = self.conv1x1(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out += residual
        out = self.relu(out)
        return out
class Decoder(nn.Module):
    """Decoder block: two 3x3x3 conv + BN + ReLU stages mapping
    in_channels -> out_channels -> out_channels at constant spatial size.
    """

    def __init__(self, in_channels, out_channels):
        # BUG FIX: the original defined `def init` and called
        # `super().init()`; the constructor must be named `__init__`.
        super(Decoder, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the conv-BN-ReLU stack to x."""
        out = self.conv(x)
        return out
class RE_Net(nn.Module):
    """RE-Net: 3D encoder/decoder segmentation network with reverse
    attention-style gating (1 - sigmoid of a coarse side prediction)
    modulating each skip connection.

    Input: (N, 1, D, H, W) with every spatial dimension divisible by 8
    (three 2x downsampling stages). Output: single-channel probability
    map of the same spatial size.
    """

    # def __init__(self, classes, channels):  # original alternative signature
    def __init__(self):
        # BUG FIX: the constructor was named `init`, so Python never ran it
        # and none of the submodules below were registered.
        super(RE_Net, self).__init__()
        self.encoder1 = ResEncoder(1, 32)
        self.encoder2 = ResEncoder(32, 64)
        self.encoder3 = ResEncoder(64, 128)
        self.bridge = ResEncoder(128, 256)
        # 1x1x1 convs produce single-channel coarse maps for the
        # reverse-attention gates at each scale.
        self.conv1_1 = nn.Conv3d(256, 1, kernel_size=1)
        self.conv2_2 = nn.Conv3d(128, 1, kernel_size=1)
        self.conv3_3 = nn.Conv3d(64, 1, kernel_size=1)
        # Upsample the single-channel gate maps to the skip's resolution.
        self.convTrans1 = nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2)
        self.convTrans2 = nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2)
        self.convTrans3 = nn.ConvTranspose3d(1, 1, kernel_size=2, stride=2)
        self.decoder3 = Decoder(256, 128)
        self.decoder2 = Decoder(128, 64)
        self.decoder1 = Decoder(64, 32)
        self.down = downsample()
        self.up3 = deconv(256, 128)
        self.up2 = deconv(128, 64)
        self.up1 = deconv(64, 32)
        self.final = nn.Conv3d(32, 1, kernel_size=1, padding=0)
        initialize_weights(self)

    def forward(self, x):
        # ---- Encoder path with reverse-attention gates on the skips ----
        enc1 = self.encoder1(x)
        down1 = self.down(enc1)
        enc2 = self.encoder2(down1)
        down2 = self.down(enc2)
        # Gate for the enc1 skip: (1 - sigmoid(coarse map from enc2)),
        # broadcast over channels, then residual-added back to enc1.
        con3_3 = self.conv3_3(enc2)
        convTrans3 = self.convTrans3(con3_3)
        x3 = -1 * (torch.sigmoid(convTrans3)) + 1
        x3 = x3.expand(-1, 32, -1, -1, -1).mul(enc1)
        x3 = x3 + enc1
        enc3 = self.encoder3(down2)
        down3 = self.down(enc3)
        # Gate for the enc2 skip, derived from enc3.
        con2_2 = self.conv2_2(enc3)
        convTrans2 = self.convTrans2(con2_2)
        x2 = -1 * (torch.sigmoid(convTrans2)) + 1
        x2 = x2.expand(-1, 64, -1, -1, -1).mul(enc2)
        x2 = x2 + enc2
        bridge = self.bridge(down3)
        # Gate for the enc3 skip, derived from the bridge.
        conv1_1 = self.conv1_1(bridge)
        convTrans1 = self.convTrans1(conv1_1)
        x = -1 * (torch.sigmoid(convTrans1)) + 1
        x = x.expand(-1, 128, -1, -1, -1).mul(enc3)
        x = x + enc3
        # ---- Decoder path: upsample, concat gated skip, decode ----
        up3 = self.up3(bridge)
        up3 = torch.cat((up3, x), dim=1)
        dec3 = self.decoder3(up3)
        up2 = self.up2(dec3)
        up2 = torch.cat((up2, x2), dim=1)
        dec2 = self.decoder2(up2)
        up1 = self.up1(dec2)
        up1 = torch.cat((up1, x3), dim=1)
        dec1 = self.decoder1(up1)
        final = self.final(dec1)
        # FIX: F.sigmoid is deprecated — use torch.sigmoid (same result).
        final = torch.sigmoid(final)
        return final
I have researched and found out the issue has something to do with batch normalization, and I have tried to adjust the numbers, but I do not know which line is causing the problem. If anyone could help, that would be great. Thanks!