So the error is occurring on line 678 of your code, and it's most likely because `out` and `x16` aren't broadcastable. Could you print the sizes of these tensors?
Also, you can copy code in by wrapping it in three backticks ```
For example,
x = torch.randn(10,10)
Thanks for your quick response. I could not find the sizes of these tensors, but here are model.py, train.py, and dataset.py:
# model.py
import torch
import torch.nn as nn
# import torch.optim as optim
import torch.nn.functional as F
# import argparse
# import torch.utils.data.sampler as sampler
# from torch.autograd import Variable
# import numpy as np
# from torch.autograd import Variable
class double_conv(nn.Module):
    """Two consecutive (Conv3d 3x3x3 -> BatchNorm3d -> ReLU) stages."""

    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        stages = []
        # The first stage maps in_ch -> out_ch, the second keeps out_ch.
        for src_ch in (in_ch, out_ch):
            stages.append(nn.Conv3d(src_ch, out_ch, 3, padding=1))
            stages.append(nn.BatchNorm3d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run both stages on ``x`` and return the result."""
        return self.conv(x)
# Encoding block in U-Net
class enc_block(nn.Module):
    """Encoder stage: ``double_conv`` followed by 2x max pooling.

    ``Dropout3d`` is applied to the pooled output only when
    ``dropout_prob`` is positive.
    """

    def __init__(self, in_ch, out_ch, dropout_prob=0.0):
        super(enc_block, self).__init__()
        self.conv = double_conv(in_ch, out_ch)
        self.down = nn.MaxPool3d(2)
        self.dropout_prob = dropout_prob
        if dropout_prob > 0:
            self.dropout = nn.Dropout3d(p=dropout_prob)

    def forward(self, x):
        """Return ``(pooled, pre_pool)``; ``pre_pool`` serves as the skip feature."""
        features = self.conv(x)
        pooled = self.down(features)
        if not self.dropout_prob > 0:
            return pooled, features
        return self.dropout(pooled), features
# Decoding block in U-Net
class dec_block(nn.Module):
    """Decoder stage: ``double_conv`` followed by 2x upsampling.

    With ``bilinear=True`` the upsampling is a trilinear ``nn.Upsample``;
    otherwise it is a stride-2 ``ConvTranspose3d``. ``Dropout3d`` is applied
    to the upsampled output only when ``dropout_prob`` is positive.
    """

    def __init__(self, in_ch, out_ch, bilinear=True, dropout_prob=0.0):
        super(dec_block, self).__init__()
        self.conv = double_conv(in_ch, out_ch)
        if bilinear:
            upsampler = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        else:
            upsampler = nn.ConvTranspose3d(out_ch, out_ch, 2, stride=2)
        self.up = upsampler
        self.dropout_prob = dropout_prob
        if dropout_prob > 0:
            self.dropout = nn.Dropout3d(p=dropout_prob)

    def forward(self, x):
        """Return ``(upsampled, pre_upsample)``."""
        features = self.conv(x)
        upsampled = self.up(features)
        if not self.dropout_prob > 0:
            return upsampled, features
        return self.dropout(upsampled), features
def concatenate(x1, x2):
    """Pad ``x1`` to the spatial size of ``x2`` and concatenate on channels.

    Both tensors are treated as ``(N, C, *spatial)``. Every spatial dimension
    (dim 2 and up) of ``x1`` is padded so that it matches ``x2``; an odd size
    difference puts the extra element on the trailing side. A negative
    difference is passed to ``F.pad`` unchanged, which crops.

    The previous version compared only dims 2 and 3 and padded only the last
    two dims, which is correct for 4-D input but mismatches the axes for the
    5-D tensors produced by the Conv3d networks in this file.

    Args:
        x1: tensor to be padded (placed second in the concatenation).
        x2: reference tensor (placed first in the concatenation).

    Returns:
        ``torch.cat([x2, padded_x1], dim=1)``.
    """
    pad = []
    # F.pad takes (before, after) pairs starting from the LAST dimension.
    for dim in range(x1.dim() - 1, 1, -1):
        diff = x2.size(dim) - x1.size(dim)
        pad.extend((diff // 2, diff - diff // 2))
    if any(pad):
        x1 = F.pad(x1, pad)
    return torch.cat([x2, x1], dim=1)
class softmax(nn.Module):
    """Softmax applied independently to consecutive channel pairs.

    Channels ``[2*i, 2*i + 1]`` are normalised together for each
    ``i < cls_num``; any channels beyond ``2 * cls_num`` are left at zero.
    """

    def __init__(self, cls_num):
        super(softmax, self).__init__()
        self.softmax = nn.Softmax(dim=1)
        self.cls_num = cls_num

    def forward(self, x):
        result = torch.zeros_like(x)
        for start in range(0, 2 * self.cls_num, 2):
            pair = slice(start, start + 2)
            result[:, pair] = self.softmax(x[:, pair])
        return result
class fuse_conv(nn.Module):
    """Channel-preserving 1x1x1 convolution followed by ReLU."""

    def __init__(self, ch):
        super(fuse_conv, self).__init__()
        pointwise = nn.Conv3d(ch, ch, 1)
        self.conv = nn.Sequential(pointwise, nn.ReLU(inplace=True))

    def forward(self, x):
        return self.conv(x)
# U-Net (baseline)
class UNet(nn.Module):
    """Baseline 3-D U-Net with four encoder and four decoder stages.

    The final 1x1x1 convolution emits ``cls_num * 2`` channels, which are
    normalised pair-wise by the ``softmax`` module.
    """

    def __init__(self, in_ch, cls_num, base_ch=64):
        super(UNet, self).__init__()
        self.in_ch = in_ch
        self.cls_num = cls_num
        self.base_ch = base_ch
        c = base_ch
        # encoding path
        self.enc1 = enc_block(in_ch, c)
        self.enc2 = enc_block(c, c*2)
        self.enc3 = enc_block(c*2, c*4)
        self.enc4 = enc_block(c*4, c*8)
        # decoding path; inputs from dec2 on are (upsampled + skip) channels
        self.dec1 = dec_block(c*8, c*8, bilinear=False)
        self.dec2 = dec_block(c*8 + c*8, c*4, bilinear=False)
        self.dec3 = dec_block(c*4 + c*4, c*2, bilinear=False)
        self.dec4 = dec_block(c*2 + c*2, c, bilinear=False)
        self.lastconv = double_conv(c + c, c)
        self.outconv = nn.Conv3d(c, cls_num*2, 1)
        self.softmax = softmax(cls_num)

    def forward(self, x):
        feat = x
        skips = []
        for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
            feat, pre_pool = encoder(feat)
            skips.append(pre_pool)
        up, _ = self.dec1(feat)
        # skips are consumed deepest-first: enc4_conv, enc3_conv, enc2_conv
        for decoder in (self.dec2, self.dec3, self.dec4):
            up, _ = decoder(concatenate(up, skips.pop()))
        fused = self.lastconv(concatenate(up, skips.pop()))
        return self.softmax(self.outconv(fused))

    def description(self):
        template = 'U-Net with {0:d}-ch input for {1:d}-class segmentation (base channel = {2:d})'
        return template.format(self.in_ch, self.cls_num, self.base_ch)
# U-Net with three decoding paths sharing one encoding path
# each decoding path predicts a binary mask of one target organ
class UNet2(nn.Module):
    """U-Net with one shared encoder and three independent decoding branches.

    Each branch ends in a 2-channel softmax; the three branch outputs are
    concatenated along the channel axis.
    """

    def __init__(self, in_ch, cls_num, base_ch=64):
        super(UNet2, self).__init__()
        self.in_ch = in_ch
        self.cls_num = cls_num
        self.base_ch = base_ch
        c = base_ch
        self.enc1 = enc_block(in_ch, c)
        self.enc2 = enc_block(c, c*2)
        self.enc3 = enc_block(c*2, c*4)
        self.enc4 = enc_block(c*4, c*8)
        # Branch b owns dec{b}1..dec{b}4, lastconv{b}, outconv{b}, softmax{b}.
        for branch in (1, 2, 3):
            prefix = 'dec{}'.format(branch)
            setattr(self, prefix + '1', dec_block(c*8, c*4, bilinear=False))
            setattr(self, prefix + '2', dec_block(c*4 + c*8, c*2, bilinear=False))
            setattr(self, prefix + '3', dec_block(c*2 + c*4, c, bilinear=False))
            setattr(self, prefix + '4', dec_block(c + c*2, c//2, bilinear=False))
            setattr(self, 'lastconv{}'.format(branch), double_conv(c//2 + c, c//2))
            setattr(self, 'outconv{}'.format(branch), nn.Conv3d(c//2, 2, 1))
            setattr(self, 'softmax{}'.format(branch), nn.Softmax(dim=1))

    def forward(self, x):
        enc1, enc1_conv = self.enc1(x)
        enc2, enc2_conv = self.enc2(enc1)
        enc3, enc3_conv = self.enc3(enc2)
        enc4, enc4_conv = self.enc4(enc3)
        branch_outputs = []
        for branch in (1, 2, 3):
            up, _ = getattr(self, 'dec{}1'.format(branch))(enc4)
            stages = ((2, enc4_conv), (3, enc3_conv), (4, enc2_conv))
            for stage, skip in stages:
                decoder = getattr(self, 'dec{}{}'.format(branch, stage))
                up, _ = decoder(concatenate(up, skip))
            last = getattr(self, 'lastconv{}'.format(branch))(concatenate(up, enc1_conv))
            logits = getattr(self, 'outconv{}'.format(branch))(last)
            branch_outputs.append(getattr(self, 'softmax{}'.format(branch))(logits))
        return torch.cat(branch_outputs, dim=1)

    def description(self):
        template = 'U-Net with three decoding paths sharing one encoding path [{0:d}-ch input][{1:d}-class seg][base ch={2:d}]'
        return template.format(self.in_ch, self.cls_num, self.base_ch)
# model 'UNet2' with intermediate 1x1 conv layers re-weighting features maps from different organ branches (decoding path)
class UNet2_ch1x1(nn.Module):
    """``UNet2`` variant whose three branches exchange features after each stage.

    After every decoder stage the three branch outputs are concatenated,
    passed through a ``fuse_conv`` (1x1x1 conv + ReLU) and split back into
    three equally sized chunks that feed the next stage.
    """

    def __init__(self, in_ch, cls_num, base_ch=64):
        super(UNet2_ch1x1, self).__init__()
        self.in_ch = in_ch
        self.cls_num = cls_num
        self.base_ch = base_ch
        c = base_ch
        self.enc1 = enc_block(in_ch, c)
        self.enc2 = enc_block(c, c*2)
        self.enc3 = enc_block(c*2, c*4)
        self.enc4 = enc_block(c*4, c*8)
        # (in_channels, out_channels) of decoder stages 1..4, same for each branch
        stage_channels = (
            (c*8, c*4),
            (c*4 + c*8, c*2),
            (c*2 + c*4, c),
            (c + c*2, c//2),
        )
        for branch in (1, 2, 3):
            for stage, (ch_in, ch_out) in enumerate(stage_channels, start=1):
                setattr(self, 'dec{}{}'.format(branch, stage),
                        dec_block(ch_in, ch_out, bilinear=False))
        self.fuse1 = fuse_conv(c*4*3)
        self.fuse2 = fuse_conv(c*2*3)
        self.fuse3 = fuse_conv(c*3)
        self.fuse4 = fuse_conv((c//2)*3)
        for branch in (1, 2, 3):
            setattr(self, 'lastconv{}'.format(branch), double_conv(c//2 + c, c//2))
            setattr(self, 'outconv{}'.format(branch), nn.Conv3d(c//2, 2, 1))
            setattr(self, 'softmax{}'.format(branch), nn.Softmax(dim=1))

    def forward(self, x):
        enc1, enc1_conv = self.enc1(x)
        enc2, enc2_conv = self.enc2(enc1)
        enc3, enc3_conv = self.enc3(enc2)
        enc4, enc4_conv = self.enc4(enc3)
        branches = (1, 2, 3)

        def fuse(stage, feats, width):
            # mix the three branches, then hand each branch its own chunk back
            mixed = getattr(self, 'fuse{}'.format(stage))(torch.cat(feats, dim=1))
            first, second, third = torch.split(mixed, width, dim=1)
            return [first, second, third]

        def decode(stage, feats, skip):
            return [
                getattr(self, 'dec{}{}'.format(b, stage))(concatenate(f, skip))[0]
                for b, f in zip(branches, feats)
            ]

        feats = [getattr(self, 'dec{}1'.format(b))(enc4)[0] for b in branches]
        feats = fuse(1, feats, self.base_ch*4)
        feats = decode(2, feats, enc4_conv)
        feats = fuse(2, feats, self.base_ch*2)
        feats = decode(3, feats, enc3_conv)
        feats = fuse(3, feats, self.base_ch)
        feats = decode(4, feats, enc2_conv)
        feats = fuse(4, feats, self.base_ch//2)

        outputs = []
        for b, f in zip(branches, feats):
            last = getattr(self, 'lastconv{}'.format(b))(concatenate(f, enc1_conv))
            logits = getattr(self, 'outconv{}'.format(b))(last)
            outputs.append(getattr(self, 'softmax{}'.format(b))(logits))
        return torch.cat(outputs, dim=1)

    def description(self):
        template = '<UNet2> with intermediate 1x1 conv layers re-weighting features maps from different organ branches (decoding path) [{0:d}-ch input][{1:d}-class seg][base ch={2:d}]'
        return template.format(self.in_ch, self.cls_num, self.base_ch)
#################################### TOD Net #####################################
# modified from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
class ResnetBlock(nn.Module):
    """3-D residual block: ``x + conv_block(x)`` with two 3x3x3 convolutions."""

    def __init__(self, dim, padding_type='reflect', norm_layer=nn.BatchNorm3d, use_dropout=False, use_bias=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Assemble pad/conv/norm/ReLU[/dropout]/pad/conv/norm as one Sequential.

        Raises:
            NotImplementedError: if ``padding_type`` is not one of
                ``'reflect'``, ``'replicate'`` or ``'zero'``.
        """

        def padded_conv():
            # 'zero' pads inside the convolution; the other modes prepend an
            # explicit padding layer and leave the convolution unpadded.
            prefix = []
            conv_padding = 0
            if padding_type == 'reflect':
                prefix.append(nn.ReflectionPad3d(1))
            elif padding_type == 'replicate':
                prefix.append(nn.ReplicationPad3d(1))
            elif padding_type == 'zero':
                conv_padding = 1
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            conv = nn.Conv3d(dim, dim, kernel_size=3, padding=conv_padding, bias=use_bias)
            return prefix + [conv]

        layers = padded_conv()
        layers.append(norm_layer(dim))
        layers.append(nn.ReLU(True))
        if use_dropout:
            layers.append(nn.Dropout(0.5))
        layers.extend(padded_conv())
        layers.append(norm_layer(dim))
        return nn.Sequential(*layers)

    def forward(self, x):
        residual = self.conv_block(x)
        return x + residual
'''WGAN discriminator'''
class Discriminator(nn.Module):
    """Three Conv3d/BatchNorm3d/LeakyReLU stages followed by one linear unit.

    The flattening in ``forward`` uses a fixed feature count of
    ``128 * 4 * out_shape ** 2`` per row.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.out_shape = 32

        def conv_stage(c_in, c_out, **conv_kwargs):
            return [
                nn.Conv3d(in_channels=c_in, out_channels=c_out, kernel_size=3,
                          padding=1, **conv_kwargs),
                nn.BatchNorm3d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        self.conv1 = nn.Sequential(*conv_stage(1, 32, stride=2))
        self.conv2 = nn.Sequential(*conv_stage(32, 64, stride=2))
        # the last stage downsamples by average pooling instead of a stride
        self.conv3 = nn.Sequential(*(conv_stage(64, 128) + [nn.AvgPool3d(2)]))
        flat_features = 128 * 4 * self.out_shape ** 2
        self.linear = nn.Sequential(
            torch.nn.Linear(flat_features, 1),
        )

    def forward(self, x):
        # Convolutional layers
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)
        flat = x.view(-1, 128 * 4 * self.out_shape ** 2)
        return self.linear(flat)
#################################### WGAN #####################################
'''WGAN generator'''
class Generator(nn.Module):
    """Fully convolutional generator with additive skip connections.

    Four Conv3d/InstanceNorm3d/LeakyReLU stages widen the features
    (1 -> 32 -> 64 -> 128 -> 256); four further convolutions narrow them
    again, each adding the matching earlier activation before its
    normalisation. The input itself is added before the final ``Tanh``.
    """

    def __init__(self):
        super(Generator, self).__init__()

        def norm_act(channels):
            return [nn.InstanceNorm3d(channels), nn.LeakyReLU(0.2, inplace=True)]

        def conv_norm_act(c_in, c_out):
            return nn.Sequential(nn.Conv3d(c_in, c_out, kernel_size=3, padding=1),
                                 *norm_act(c_out))

        self.conv1 = conv_norm_act(1, 32)
        self.conv2 = conv_norm_act(32, 64)
        self.conv3 = conv_norm_act(64, 128)
        self.conv4 = conv_norm_act(128, 256)
        # In the narrowing half the convolution is kept outside the Sequential
        # so the skip tensor can be added between conv and normalisation.
        self.deConv1_1 = nn.Conv3d(256, 128, kernel_size=3, padding=1)
        self.deConv1 = nn.Sequential(*norm_act(128))
        self.deConv2_1 = nn.Conv3d(128, 64, kernel_size=3, padding=1)
        self.deConv2 = nn.Sequential(*norm_act(64))
        self.deConv3_1 = nn.Conv3d(64, 32, kernel_size=3, padding=1)
        self.deConv3 = nn.Sequential(*norm_act(32))
        self.deConv4_1 = nn.Conv3d(32, 1, kernel_size=3, padding=1)
        self.deConv4 = nn.Tanh()  # nn.ReLU()

    def forward(self, input):
        feat32 = self.conv1(input)
        feat64 = self.conv2(feat32)
        feat128 = self.conv3(feat64)
        feat256 = self.conv4(feat128)

        merged = self.deConv1_1(feat256)
        merged = merged + feat128
        narrowed = self.deConv1(merged)

        merged = self.deConv2_1(narrowed)
        merged += feat64
        narrowed = self.deConv2(merged)

        merged = self.deConv3_1(narrowed)
        merged += feat32
        narrowed = self.deConv3(merged)

        merged = self.deConv4_1(narrowed)
        merged += input
        return self.deConv4(merged)
'''WGAN discriminator'''
'''
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.out_shape = 32
self.conv1 = nn.Sequential(
nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1),
nn.BatchNorm3d(32),
nn.LeakyReLU(0.2, inplace=True)
)
self.conv2 = nn.Sequential(
nn.Conv3d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
nn.BatchNorm3d(64),
nn.LeakyReLU(0.2, inplace=True)
)
self.conv3 = nn.Sequential(
nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
nn.BatchNorm3d(128),
nn.LeakyReLU(0.2, inplace=True),
nn.AvgPool3d(2)
)
############################################################
self.linear = nn.Sequential(
torch.nn.Linear(128 * 4 * self.out_shape ** 2, 1),
# nn.LeakyReLU(inplace=True)
)
def forward(self, x):
# Convolutional layers
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
# Flatten and apply sigmoid
x = x.view(-1, 128 * 4 * self.out_shape ** 2)
x = self.linear(x)
return x
'''
#################################### VoxResNet #####################################
import torch
import torch.nn as nn
class softmax(nn.Module):
    """Pair-wise channel softmax: channels (2i, 2i+1) are normalised together.

    Only the first ``2 * cls_num`` channels are written; the rest stay zero.
    """

    def __init__(self, cls_num):
        super(softmax, self).__init__()
        self.softmax = nn.Softmax(dim=1)
        self.cls_num = cls_num

    def forward(self, x):
        out = torch.zeros_like(x)
        for cls_idx in range(self.cls_num):
            lo, hi = cls_idx * 2, cls_idx * 2 + 2
            out[:, lo:hi] = self.softmax(x[:, lo:hi])
        return out
class VoxRes(nn.Module):
    """Residual unit: two (BatchNorm3d -> ReLU -> Conv3d 3x3x3) stages plus identity."""

    def __init__(self, in_channel):
        super(VoxRes, self).__init__()
        layers = []
        for _ in range(2):
            layers += [
                nn.BatchNorm3d(in_channel),
                nn.ReLU(),
                nn.Conv3d(in_channel, in_channel, kernel_size=3, padding=1),
            ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return residual + x
class VoxResNet(nn.Module):
    """VoxResNet-style network with four feature levels and summed heads.

    Levels 2-4 each downsample the last two spatial axes by a stride of
    (1, 2, 2). Every level has its own transposed-conv + 1x1x1 conv head
    producing ``num_class * 2`` channels; the four head outputs are summed
    and normalised pair-wise by ``softmax``.
    """

    def __init__(self, in_channels, num_class):
        super(VoxResNet, self).__init__()

        def down_stage(c_in, c_out):
            return nn.Sequential(
                nn.BatchNorm3d(c_in),
                nn.ReLU(),
                nn.Conv3d(c_in, c_out, kernel_size=3, stride=(1, 2, 2), padding=1),
                VoxRes(c_out),
                VoxRes(c_out),
            )

        def head(channels, factor):
            window = (1, factor, factor)
            return nn.Sequential(
                nn.ConvTranspose3d(channels, channels, kernel_size=window, stride=window),
                nn.Conv3d(channels, num_class*2, kernel_size=1))

        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.Conv3d(32, 32, kernel_size=3, padding=1)
        )
        self.conv2 = down_stage(32, 64)
        self.conv3 = down_stage(64, 64)
        self.conv4 = down_stage(64, 64)
        self.deconv_c1 = head(32, 1)
        self.deconv_c2 = head(64, 2)
        self.deconv_c3 = head(64, 4)
        self.deconv_c4 = head(64, 8)
        # self.outconv = nn.Conv3d(num_class, num_class, 1)
        self.softmax = softmax(num_class)

    def forward(self, x):
        level1 = self.conv1(x)
        level2 = self.conv2(level1)
        level3 = self.conv3(level2)
        level4 = self.conv4(level3)
        score1 = self.deconv_c1(level1)
        score2 = self.deconv_c2(level2)
        score3 = self.deconv_c3(level3)
        score4 = self.deconv_c4(level4)
        return self.softmax(score1 + score2 + score3 + score4)
# missing classes I have added
def ELUCons(elu, nchan):
    """Return an in-place ``nn.ELU`` if ``elu`` is truthy, else ``nn.PReLU(nchan)``."""
    return nn.ELU(inplace=True) if elu else nn.PReLU(nchan)
def passthrough(x, **kwargs):
    """Identity placeholder: return ``x`` unchanged and ignore keyword arguments."""
    del kwargs  # accepted only for call-signature compatibility
    return x
def _make_nConv(nchan, depth, elu):
    """Stack ``depth`` ``LUConv(nchan, elu)`` modules into one ``nn.Sequential``."""
    return nn.Sequential(*[LUConv(nchan, elu) for _ in range(depth)])
class ContBatchNorm3d(nn.BatchNorm3d):
    """BatchNorm3d that always normalises with batch statistics.

    ``forward`` calls ``F.batch_norm`` with ``training=True`` regardless of
    the module's train/eval mode.
    """

    def _check_input_dim(self, input):
        ndim = input.dim()
        if ndim != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(ndim))
        super(ContBatchNorm3d, self)._check_input_dim(input)

    def forward(self, input):
        self._check_input_dim(input)
        return F.batch_norm(
            input,
            self.running_mean,
            self.running_var,
            weight=self.weight,
            bias=self.bias,
            training=True,
            momentum=self.momentum,
            eps=self.eps,
        )
class LUConv(nn.Module):
    """Channel-preserving 5x5x5 Conv3d -> ContBatchNorm3d -> ELU/PReLU."""

    def __init__(self, nchan, elu):
        super(LUConv, self).__init__()
        self.relu1 = ELUCons(elu, nchan)
        self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(nchan)

    def forward(self, x):
        normalised = self.bn1(self.conv1(x))
        return self.relu1(normalised)
class InputTransition(nn.Module):
    """V-Net input stage: 5x5x5 conv + batch norm with a replicated-input residual.

    The single-channel input is repeated along the channel axis so that it
    can be added to the ``outChans``-channel convolution output.

    Args:
        outChans: number of output channels (previously ignored and
            hard-coded to 16; ``VNet`` passes 16, so its behaviour and
            parameter shapes are unchanged).
        elu: forwarded to ``ELUCons`` to choose the activation.
    """

    def __init__(self, outChans, elu):
        super(InputTransition, self).__init__()
        self.conv1 = nn.Conv3d(1, outChans, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(outChans)
        self.relu1 = ELUCons(elu, outChans)

    def forward(self, x):
        # do we want a PRELU here as well?
        out = self.bn1(self.conv1(x))
        # Replicate the input along the CHANNEL axis (dim 1). Concatenating
        # along dim 0 multiplies the batch size instead, so the result cannot
        # be added to `out`, whose batch size equals that of `x`.
        x_rep = torch.cat([x] * out.size(1), 1)
        out = self.relu1(torch.add(out, x_rep))
        return out
class DownTransition(nn.Module):
    """V-Net down stage: stride-2 conv doubling the channels, then a residual conv stack.

    ``do1`` is an ``nn.Dropout3d`` when ``dropout`` is true and the identity
    ``passthrough`` function otherwise.
    """

    def __init__(self, inChans, nConvs, elu, dropout=False):
        super(DownTransition, self).__init__()
        doubled = 2*inChans
        self.down_conv = nn.Conv3d(inChans, doubled, kernel_size=2, stride=2)
        self.bn1 = ContBatchNorm3d(doubled)
        self.do1 = passthrough
        self.relu1 = ELUCons(elu, doubled)
        self.relu2 = ELUCons(elu, doubled)
        if dropout:
            self.do1 = nn.Dropout3d()
        self.ops = _make_nConv(doubled, nConvs, elu)

    def forward(self, x):
        reduced = self.relu1(self.bn1(self.down_conv(x)))
        refined = self.ops(self.do1(reduced))
        # residual connection around the conv stack
        return self.relu2(torch.add(refined, reduced))
class UpTransition(nn.Module):
    """V-Net up stage: transposed conv to ``outChans // 2`` channels, skip concat, residual conv stack.

    The skip tensor always passes through ``Dropout3d`` (``do2``); the main
    input passes through ``do1``, which is a ``Dropout3d`` only when
    ``dropout`` is true.
    """

    def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
        super(UpTransition, self).__init__()
        half = outChans // 2
        self.up_conv = nn.ConvTranspose3d(inChans, half, kernel_size=2, stride=2)
        self.bn1 = ContBatchNorm3d(half)
        self.do1 = passthrough
        self.do2 = nn.Dropout3d()
        self.relu1 = ELUCons(elu, half)
        self.relu2 = ELUCons(elu, outChans)
        if dropout:
            self.do1 = nn.Dropout3d()
        self.ops = _make_nConv(outChans, nConvs, elu)

    def forward(self, x, skipx):
        dropped = self.do1(x)
        skip = self.do2(skipx)
        upsampled = self.relu1(self.bn1(self.up_conv(dropped)))
        joined = torch.cat((upsampled, skip), 1)
        refined = self.ops(joined)
        # residual connection around the conv stack
        return self.relu2(torch.add(refined, joined))
class OutputTransition(nn.Module):
    """V-Net output stage: reduce to 2 channels and return per-voxel (log-)softmax.

    Args:
        inChans: number of input channels.
        elu: forwarded to ``ELUCons`` to choose the activation.
        nll: if true use ``F.log_softmax``, otherwise ``F.softmax``.

    ``forward`` returns a 2-D tensor of shape ``(numel // 2, 2)``: the
    channel axis is moved last and all other axes are flattened.
    """

    def __init__(self, inChans, elu, nll):
        super(OutputTransition, self).__init__()
        self.conv1 = nn.Conv3d(inChans, 2, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(2)
        self.conv2 = nn.Conv3d(2, 2, kernel_size=1)
        self.relu1 = ELUCons(elu, 2)
        self.softmax = F.log_softmax if nll else F.softmax

    def forward(self, x):
        # convolve 32 down to 2 channels
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.conv2(out)
        # make channels the last axis
        out = out.permute(0, 2, 3, 4, 1).contiguous()
        # flatten
        out = out.view(out.numel() // 2, 2)
        # Pass dim explicitly: relying on the implicit softmax dimension is
        # deprecated. For this 2-D tensor the class axis is dim 1, which is
        # also what the implicit choice resolved to.
        out = self.softmax(out, dim=1)
        # NOTE(review): the result is flattened to (voxels, 2); callers that
        # slice it like a 5-D (N, C, D, H, W) tensor should confirm the shape
        # they expect.
        return out
#################################### VNet #####################################
class VNet(nn.Module):
    """V-Net: input stage, four down stages, four up stages with skips, output stage."""

    # the number of convolutions in each layer corresponds
    # to what is in the actual prototxt, not the intent
    def __init__(self, elu=True, nll=False):
        super(VNet, self).__init__()
        self.in_tr = InputTransition(16, elu)
        # encoder
        self.down_tr32 = DownTransition(16, 1, elu)
        self.down_tr64 = DownTransition(32, 2, elu)
        self.down_tr128 = DownTransition(64, 3, elu, dropout=True)
        self.down_tr256 = DownTransition(128, 2, elu, dropout=True)
        # decoder
        self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=True)
        self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=True)
        self.up_tr64 = UpTransition(128, 64, 1, elu)
        self.up_tr32 = UpTransition(64, 32, 1, elu)
        self.out_tr = OutputTransition(32, elu, nll)

    def forward(self, x):
        # skips collects the 16/32/64/128-channel features in that order
        skips = [self.in_tr(x)]
        for down in (self.down_tr32, self.down_tr64, self.down_tr128):
            skips.append(down(skips[-1]))
        out = self.down_tr256(skips[-1])
        # each up stage consumes the deepest remaining skip
        for up in (self.up_tr256, self.up_tr128, self.up_tr64, self.up_tr32):
            out = up(out, skips.pop())
        return self.out_tr(out)
# train.py
# !nvidia-smi
# import sys
import os
import numpy as np
# import random
import time
# os.environ['CUDA_VISIBLE_DEVICES'] = "3"
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils import data
# torch.cuda.current_device()
# from dataset import *
from dataset import create_folds, DatasetStk
from model import VNet
# VoxRes, VoxResNet, UNet, UNet2, UNet2_ch1x1, ResUNet, ResnetBlock
from loss import DiceLoss
from utils import init_model, save_model
from metric import eval
# import math
# torch.cuda.empty_cache()
class WGAN_De_Att:
    """Training / evaluation driver for a single segmentation network ``netV``.

    Holds the network and an SGD optimizer over its parameters, and offers a
    step-wise learning-rate schedule, one-epoch training, evaluation, and the
    outer training loop.
    """

    def __init__(self, netV=None):
        """Use ``netV`` if given, otherwise build ``init_model(VNet())``."""
        # initialize networks
        if netV is None:
            self.netV = init_model(VNet())
        else:
            self.netV = netV
        # initialize optimizers
        self.optimizer_V = optim.SGD(self.netV.parameters(), lr=1e-2, momentum=0.99, weight_decay=1e-8)

    def scheduler(self, epoch, lr0=0.01):
        """Set the learning rate to ``lr0*0.1`` from epoch 80 and ``lr0*0.01`` after epoch 100.

        Earlier epochs leave the optimizer's learning rate untouched.
        """
        if epoch >= 80:
            lr = lr0*0.1
            for param_group in self.optimizer_V.param_groups:
                param_group['lr'] = lr
        if epoch > 100:
            lr = lr0*0.01
            for param_group in self.optimizer_V.param_groups:
                param_group['lr'] = lr

    def batch_train(self, cfg, dl_train, epoch_id):
        """Train ``netV`` for one epoch over ``dl_train`` with a per-class Dice loss.

        Args:
            cfg: dict; reads ``'cls_num'``, ``'epoch_num'`` and ``'batch_size'``.
            dl_train: data loader yielding dicts with ``'data'``, ``'label'``
                and ``'label_exist'``; its ``dataset`` attribute is used for
                the progress percentage.
            epoch_id: zero-based epoch index (used for logging only).
        """
        self.netV.train()
        criterion = DiceLoss()
        cls_num = cfg['cls_num']
        # for this epoch
        # np.float was removed from NumPy; np.float64 is the same dtype.
        epoch_loss = np.zeros(cls_num, dtype=np.float64)
        epoch_loss_num = np.zeros(cls_num, dtype=np.int64)
        # Taken from the loader instead of the `d_train` global, which only
        # exists when the file runs as a script.
        sample_total = len(dl_train.dataset)

        for batch_id, batch in enumerate(dl_train):
            image = batch['data']
            label = batch['label']
            flag = batch['label_exist']
            n = len(image)
            image, label = image.cuda(), label.cuda()
            print_line = 'Epoch {0:d}/{1:d} (train) --- Progress {2:5.2f}% (+{3:d})'.format(
                epoch_id+1, cfg['epoch_num'], 100.0 * batch_id * cfg['batch_size'] / sample_total, n)

            # ---Dice loss---
            loss = 0
            has_labelled_class = False
            pred = self.netV(image)
            for c in range(cls_num):
                if torch.sum(flag[:, c]) > 0:
                    l = criterion(pred[:, c*2:c*2+2], label[:, c*2:c*2+2], flag[:, c])
                    loss += l
                    has_labelled_class = True
                    epoch_loss[c] += l.item()
                    epoch_loss_num[c] += 1

            if not has_labelled_class:
                # `loss` is still the int 0 here: there is nothing to
                # back-propagate, and loss.item() / loss.backward() would raise.
                print(print_line + ' -- skipped (no labelled class in batch)')
                del image, label, pred
                continue

            print_line += ' -- Dice Loss: {0:.4f}'.format(loss.item())
            print(print_line)

            self.netV.zero_grad()
            loss.backward()
            self.optimizer_V.step()
            del image, label, pred, loss

        train_loss = np.sum(epoch_loss)/np.sum(epoch_loss_num)
        epoch_loss = epoch_loss / epoch_loss_num
        print_line = 'Epoch {0:d}/{1:d} (train) --- Loss: {2:.6f} ({3:s})\n'.format(
            epoch_id+1, cfg['epoch_num'], train_loss, '/'.join(['%.6f']*len(epoch_loss)) % tuple(epoch_loss))
        print(print_line)
        #torch.cuda.empty_cache()

    def batch_eval(self, cfg, dl_val, epoch_id, loss_fn, eval_test=False):
        """Predict masks for ``dl_val``, write them to disk and print DSC/ASD.

        Args:
            cfg: dict; reads ``'cls_num'``, ``'epoch_num'``, ``'batch_size'``
                and ``'label_map'``.
            dl_val: data loader to evaluate.
            epoch_id: zero-based epoch index (logging and metric file name).
            loss_fn: accepted for interface compatibility; not used here.
            eval_test: score against ``test_fold`` when true, else ``val_fold``.

        NOTE(review): this method reads the module-level names
        ``val_result_path``, ``val_fold`` / ``test_fold``, ``resample_array``
        and ``output2file``. The first three are assigned only in the
        ``__main__`` section and the last two are not imported in the visible
        import list -- confirm they are in scope before calling.
        """
        self.netV.eval()
        output_mask = [None] * cfg['cls_num']
        sample_total = len(dl_val.dataset)

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for batch_id, batch in enumerate(dl_val):
                image = batch['data']
                label = batch['label']
                flag = batch['label_exist']
                n = len(image)
                image = image.cuda()
                label = label.cuda()
                pred = self.netV(image)
                print_line = 'Epoch {0:d}/{1:d} (val) --- Progress {2:5.2f}% (+{3:d})'.format(
                    epoch_id+1, cfg['epoch_num'], 100.0 * batch_id * cfg['batch_size'] / sample_total, n)
                print(print_line)

                for c in range(cfg['cls_num']):
                    pred_bin = torch.argmax(pred[:, c*2:c*2+2], dim=1, keepdim=True)
                    for i in range(n):
                        if flag[i, c] > 0:
                            mask = pred_bin[i, :].contiguous().cpu().numpy().copy().astype(dtype=np.uint8)
                            mask = np.squeeze(mask)
                            mask = resample_array(
                                mask, batch['size'][i].numpy(), batch['spacing'][i].numpy(), batch['origin'][i].numpy(),
                                batch['org_size'][i].numpy(), batch['org_spacing'][i].numpy(), batch['org_origin'][i].numpy())
                            # accumulate the pieces of one case until its last piece arrives
                            if output_mask[c] is None:
                                output_mask[c] = mask
                            else:
                                output_mask[c] = output_mask[c] + mask
                            if batch['eof'][i]:
                                output_mask[c][output_mask[c] > 0] = 1
                                output2file(
                                    output_mask[c], batch['org_size'][i].numpy(), batch['org_spacing'][i].numpy(), batch['org_origin'][i].numpy(),
                                    '{}/{}@{}@{}.nii.gz'.format(val_result_path, batch['dataset'][i], batch['case'][i], c+1))
                                output_mask[c] = None
                    del pred_bin
                del image, pred

        gt_entries = test_fold if eval_test is True else val_fold
        dsc, asd, dsc_m, asd_m = eval(
            pd_path=val_result_path, gt_entries=gt_entries, label_map=cfg['label_map'], cls_num=cfg['cls_num'],
            metric_fn='metric_{0:04d}'.format(epoch_id), calc_asd=False)
        print_line = 'Epoch {0:d}/{1:d} (val) --- DSC {2:.2f} ({3:s})% --- ASD {4:.2f} ({5:s})mm'.format(
            epoch_id+1, cfg['epoch_num'],
            dsc_m*100.0, '/'.join(['%.2f']*len(dsc[:, 0])) % tuple(dsc[:, 0]*100.0),
            asd_m, '/'.join(['%.2f']*len(asd[:, 0])) % tuple(asd[:, 0]))
        print(print_line)

    def Net_train(self, cfg, dl_train, dl_val, loss_fn, save_dir, if_eval=False):
        """Run ``cfg['epoch_num']`` epochs: schedule, train, save, evaluate.

        ``if_eval`` is kept for interface compatibility. Evaluation already
        ran unconditionally after every epoch; ``if_eval=True`` only triggered
        a second, identical evaluation, which is no longer performed.
        """
        for epoch_id in range(cfg['epoch_num']):
            self.scheduler(epoch_id)
            # train
            self.batch_train(cfg, dl_train, epoch_id)
            # save model
            file_name_V = 'V_epoch{}'.format(epoch_id)+'.pth'
            save_model(self.netV, save_dir, file_name_V)
            # eval (once per epoch)
            self.batch_eval(cfg, dl_val, epoch_id, loss_fn)
if __name__ == '__main__':
    # Script entry point: build the config, create result directories and
    # data loaders, train the V-Net, then evaluate a saved checkpoint on the
    # test fold.
    cfg = {}
    cfg['cls_num'] = 1
    cfg['gpu'] = '3'  # to use multiple gpu: cfg['gpu'] = '0,1,2,3'
    cfg['fold_num'] = 5
    cfg['epoch_num'] = 120
    cfg['batch_size'] = 6
    cfg['lr'] = 0.001
    cfg['Unet_path'] = ' '
    cfg['model_path'] = ' '
    cfg['rs_size'] = [256,256,32]  # resample size: [x, y, z]
    cfg['rs_spacing'] = [1.5,1.5,3.0]  # resample spacing: [x, y, z]. non-positive value means adaptive spacing fit the physical size: rs_size * rs_spacing = origin_size * origin_spacing
    cfg['rs_intensity'] = [-200.0, 200.0]  # rescale intensity from [min, max] to [0, 1].
    cfg['cpu_thread'] = 8  # multi-thread for data loading. zero means single thread.

    # list of dataset names and paths
    # The pasted literal had an unbalanced '[' here (a SyntaxError). The entry
    # is restored with a blank placeholder path, like 'data_path_test' below;
    # fill in the real dataset directory before running.
    cfg['data_path_train'] = [
        ['KiTS', ' '],
        # ['KiTS', '/home/41/hh19/data_new2']
    ]
    cfg['data_path_test'] = [
        ['BTCV', ' ']
    ]
    cfg['label_map'] = {
        'KiTS':{1:1}
        #'BTCV':{2:1}
    }

    # exclude any samples in the form of '[dataset_name, case_name]'
    cfg['exclude_case'] = [
        ['KiTS', 'case_00133'],
        ['KiTS', 'case_00134'],
        ['KiTS', 'case_00135'],
        ['KiTS', 'case_00136'],
        ['KiTS', 'case_00137'],
        ['KiTS', 'case_00138'],
        ['KiTS', 'case_00139'],
        ['KiTS', 'case_00140'],
        ['KiTS', 'case_00141'],
        ['KiTS', 'case_00142'],
        # ['LiTS', 'volume-102']
    ]

    os.environ["CUDA_VISIBLE_DEVICES"] = cfg['gpu']
    train_start_time = time.localtime()
    time_stamp = time.strftime("%Y%m%d%H%M%S", train_start_time)
    # acc_time = 0

    # create directory for results storage
    store_dir = '{}/model_{}'.format(cfg['model_path'], time_stamp)
    os.makedirs(store_dir, exist_ok=True)
    best_model_fn = '{}/epoch_{}.pth.tar'.format(store_dir, 1)
    loss_fn = '{}/loss.txt'.format(store_dir)
    log_fn = '{}/log.txt'.format(store_dir)
    val_result_path = '{}/results_val'.format(store_dir)
    os.makedirs(val_result_path, exist_ok=True)
    test_result_path = '{}/results_test'.format(store_dir)
    os.makedirs(test_result_path, exist_ok=True)

    # Dataloader
    folds, _ = create_folds(data_path=cfg['data_path_train'], fold_num=cfg['fold_num'], exclude_case=cfg['exclude_case'])

    # create training fold: all folds except the last two
    train_fold = []
    for i in range(cfg['fold_num'] - 2):
        train_fold.extend(folds[i])
    # d_train = Dataset(train_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'])
    d_train = DatasetStk(train_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'], perturb=True)
    dl_train = data.DataLoader(dataset=d_train, batch_size=cfg['batch_size'], shuffle=True, pin_memory=True, drop_last=True, num_workers=cfg['cpu_thread'])

    # create validation fold: the second-to-last fold
    val_fold = folds[cfg['fold_num']-2]
    # d_val = Dataset(val_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'])
    d_val = DatasetStk(val_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'], perturb=False)
    dl_val = data.DataLoader(dataset=d_val, batch_size=cfg['batch_size'], shuffle=False, pin_memory=True, drop_last=False, num_workers=cfg['cpu_thread'])

    # create test fold: the last fold
    test_fold = folds[cfg['fold_num']-1]
    #d_val = Dataset(val_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'])
    d_test = DatasetStk(test_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'], perturb=False)
    dl_test = data.DataLoader(dataset=d_test, batch_size=cfg['batch_size'], shuffle=False, pin_memory=True, drop_last=False, num_workers=cfg['cpu_thread'])

    # --- train downstream tasks ---
    Solver = WGAN_De_Att()
    Solver.Net_train(cfg, dl_train, dl_val, loss_fn, store_dir)

    # --- eval downstream tasks ---
    ii = 119
    file_path = ' '
    model_path = os.path.join(file_path,'V_epoch'+str(ii)+'.pth')
    model = VNet()
    model.cuda()
    model.load_state_dict(torch.load(model_path))
    netU = nn.DataParallel(module=model)
    netU.eval()
    Solver = WGAN_De_Att(netU)
    Solver.batch_eval(cfg=cfg,
                      dl_val=dl_test,
                      epoch_id=0,
                      loss_fn=loss_fn,
                      eval_test=True)
# dataset.py
import os
import sys
import torch
from torch.utils import data
import itk
import numpy as np
import random
import SimpleITK as sitk
def read_image(fname, imtype):
    """Load an image file from disk through an ITK file reader.

    :param fname: path of the image file to read
    :param imtype: ITK image type used to instantiate the reader
                   (the datasets in this file pass e.g. itk.Image[itk.SS, 3])
    :return: the reader's output image, after the reader has been updated
    """
    file_reader = itk.ImageFileReader[imtype].New()
    file_reader.SetFileName(fname)
    # execute the reader before handing out its output
    file_reader.Update()
    return file_reader.GetOutput()
def scan_path(d_name, d_path):
    """Collect the cases of one dataset that are complete on disk.

    :param d_name: dataset identifier; one of 'LiTS', 'KiTS', 'BTCV', 'spleen'.
                   Any other name yields an empty list.
    :param d_path: root directory of that dataset
    :return: list of entries. 'KiTS' entries are
             [d_name, case_name, image_name, image_ld_name, label_name];
             the other datasets produce 4-element entries
             [d_name, case_name, image_name, label_name].
    :raises ValueError: if the numeric part of a matching file name is not an integer

    NOTE(review): create_folds() and both Dataset classes unpack five fields per
    entry, so only the 'KiTS' layout is consumable by them as written -- confirm
    whether the other datasets are still meant to be supported.
    """
    entries = []
    # os.listdir() returns names in arbitrary order; sort so that the scan result
    # (and therefore any seeded shuffle of it) is reproducible across machines.
    if d_name == 'LiTS':
        for f in sorted(os.listdir(d_path)):
            if f.startswith('volume-') and f.endswith('.mha'):
                case_id = int(f.split('.mha')[0].split('volume-')[1])
                label_name = '{}/segmentation-{}.mha'.format(d_path, case_id)
                # keep only volumes that have a matching segmentation
                if os.path.isfile(label_name):
                    case_name = 'volume-{}'.format(case_id)
                    image_name = '{}/volume-{}.mha'.format(d_path, case_id)
                    entries.append([d_name, case_name, image_name, label_name])
    elif d_name == 'KiTS':
        for case_name in sorted(os.listdir(d_path)):
            image_name = '{}/{}/imaging.nii.gz'.format(d_path, case_name)
            image_ld_name = '{}/{}/imaging_ld.nii.gz'.format(d_path, case_name)
            label_name = '{}/{}/segmentation.nii.gz'.format(d_path, case_name)
            # a case is usable only when all three files exist
            if os.path.isfile(image_name) and os.path.isfile(label_name) and os.path.isfile(image_ld_name):
                entries.append([d_name, case_name, image_name, image_ld_name, label_name])
    elif d_name == 'BTCV':
        for f in sorted(os.listdir(d_path)):
            if f.startswith('volume-'):
                case_id = int(f.split('.nii')[0].split('volume-')[1])
                label_name = '{}/segmentation-{}.nii.gz'.format(d_path, case_id)
                if os.path.isfile(label_name):
                    case_name = 'volume-{}'.format(case_id)
                    image_name = '{}/volume-{}.nii.gz'.format(d_path, case_id)
                    entries.append([d_name, case_name, image_name, label_name])
    elif d_name == 'spleen':
        for f in sorted(os.listdir('{}/imagesTr'.format(d_path))):
            if f.startswith('spleen_'):
                case_id = int(f.split('.nii.gz')[0].split('spleen_')[1])
                label_name = '{}/labelsTr/{}'.format(d_path, f)
                # image and label share the same file name in sibling folders
                if os.path.isfile(label_name):
                    case_name = 'spleen_{}'.format(case_id)
                    image_name = '{}/imagesTr/{}'.format(d_path, f)
                    entries.append([d_name, case_name, image_name, label_name])
    return entries
def create_folds(data_path, fold_num, exclude_case):
    """Load the cached cross-validation split, or build a new one and cache it.

    The split is cached in '<sys.path[0]>/CV_<fold_num>-fold.txt', one case per
    line in the whitespace-separated form
    ``fold_id d_name casename image_fn image_ld_fn label_fn``.
    Because the line is split on whitespace, paths containing spaces cannot be
    round-tripped through the cache.

    :param data_path: iterable of [d_name, d_path] pairs passed to scan_path()
                      (only used when no cache file exists)
    :param fold_num: number of folds
    :param exclude_case: collection of [d_name, casename] pairs to leave out
                         (only used when no cache file exists)
    :return: (folds, folds_size) where folds maps fold_id to a list of
             [d_name, casename, image_fn, image_ld_fn, label_fn] entries and
             folds_size lists the number of entries per fold
    """
    fold_file_name = '{0:s}/CV_{1:d}-fold.txt'.format(sys.path[0], fold_num)
    folds = {}
    if os.path.exists(fold_file_name):
        with open(fold_file_name, 'r') as fold_file:
            for strline in fold_file:
                params = strline.split()
                if not params:
                    # tolerate blank lines (e.g. a trailing newline)
                    continue
                fold_id = int(params[0])
                # Same column order as written below and as produced by
                # scan_path(): image_ld_fn (col 4) precedes label_fn (col 5).
                # Reading them swapped would hand the label file to the image
                # reader and vice versa.
                folds.setdefault(fold_id, []).append(
                    [params[1], params[2], params[3], params[4], params[5]])
    else:
        entries = []
        for d_name, d_path in data_path:
            entries.extend(scan_path(d_name, d_path))
        # Build a filtered list instead of calling remove() while iterating,
        # which skips the element following every removed one.
        entries = [e for e in entries if e[0:2] not in exclude_case]
        case_num = len(entries)
        fold_size = case_num // fold_num
        random.shuffle(entries)
        for fold_id in range(fold_num - 1):
            folds[fold_id] = entries[fold_id * fold_size:(fold_id + 1) * fold_size]
        # the last fold absorbs the remainder
        folds[fold_num - 1] = entries[(fold_num - 1) * fold_size:]
        with open(fold_file_name, 'w') as fold_file:
            for fold_id in range(fold_num):
                print(fold_id)
                for [d_name, casename, image_fn, image_ld_fn, label_fn] in folds[fold_id]:
                    print(d_name, casename, image_fn, image_ld_fn, label_fn)
                    fold_file.write('{0:d} {1:s} {2:s} {3:s} {4:s} {5:s}\n'.format(fold_id, d_name, casename, image_fn, image_ld_fn, label_fn))
    folds_size = [len(x) for x in folds.values()]
    return folds, folds_size
def normalize(x, min, max):
    """Clip *x* to [min, max] and rescale it linearly to [0, 1].

    The input array is left untouched; a new array is returned.

    :param x: numpy array of intensities
    :param min: lower intensity bound, mapped to 0
    :param max: upper intensity bound, mapped to 1
    :return: new array with values in [0, 1]
    :raises ZeroDivisionError: if plain Python numbers with max == min are given

    The parameter names shadow the builtins but are kept because callers pass
    them by keyword (min=..., max=...).
    """
    factor = 1.0 / (max - min)
    # np.clip returns a copy, so the caller's array is not modified in place
    clipped = np.clip(x, min, max)
    return (clipped - min) * factor
def generate_transform(identity):
    """Create the spatial transform handed to the resampler.

    :param identity: if true, return a 3D identity transform; otherwise return
                     an Euler 3D transform with random parameters
    :return: an ITK transform object

    For the random case the first three parameters are drawn uniformly from
    [-0.05, 0.05] (rotations, rad) and the last three from [-5, 5]
    (translations, mm).
    """
    if identity:
        return itk.IdentityTransform[itk.D, 3].New()

    min_rotate, max_rotate = -0.05, 0.05  # [rad]
    min_offset, max_offset = -5.0, 5.0  # [mm]
    t = itk.Euler3DTransform[itk.D].New()
    # Allocate a fresh parameter vector of the right length; every slot is
    # assigned below, so there is no need to fetch the current parameters first.
    euler_parameters = itk.OptimizerParameters[itk.D](t.GetNumberOfParameters())
    for k in range(3):
        euler_parameters[k] = min_rotate + random.random() * (max_rotate - min_rotate)  # rotate
    for k in range(3, 6):
        euler_parameters[k] = min_offset + random.random() * (max_offset - min_offset)  # translate
    t.SetParameters(euler_parameters)
    return t
def resample(image, imtype, size, spacing, origin, transform, linear, dtype):
    """Resample an ITK image onto a new grid and return it as a numpy array.

    :param image: ITK image to resample
    :param imtype: ITK image type of both input and output
    :param size: output size per axis (3 values)
    :param spacing: output spacing per axis (3 values); when origin is None a
                    non-positive value means "derive the spacing so that size
                    voxels span the original extent along this axis"
    :param origin: output origin (3 values), or None to derive it from the
                   original image
    :param transform: ITK transform applied by the resampler
    :param linear: true for linear interpolation, false for nearest neighbour
    :param dtype: numpy dtype of the returned array
    :return: dict with keys 'org_size', 'org_spacing', 'org_origin' (geometry of
             the input), 'size', 'spacing', 'origin' (geometry of the output) and
             'array' (resampled voxels with one extra leading axis)
    """
    src_origin = image.GetOrigin()
    src_spacing = image.GetSpacing()
    src_size = image.GetBufferedRegion().GetSize()

    output = {
        'org_size': np.array(src_size, dtype=int),
        'org_spacing': np.array(src_spacing, dtype=float),
        'org_origin': np.array(src_origin, dtype=float),
    }

    if origin is not None:
        # geometry fully specified by the caller
        new_size = np.array(size, dtype=int)
        new_spacing = np.array(spacing, dtype=float)
        new_origin = np.array(origin, dtype=float)
    else:
        # no origin given: derive the output grid from the original image
        new_size = np.zeros(3, dtype=int)
        new_spacing = np.zeros(3, dtype=float)
        new_origin = np.zeros(3, dtype=float)
        for axis in range(3):
            new_size[axis] = size[axis]
            if spacing[axis] > 0:
                # fixed spacing: centre the output grid on the original extent
                new_spacing[axis] = spacing[axis]
                new_origin[axis] = src_origin[axis] + src_size[axis]*src_spacing[axis]*0.5 - size[axis]*spacing[axis]*0.5
            else:
                # free spacing: stretch the grid over the original extent
                new_spacing[axis] = src_size[axis] * src_spacing[axis] / size[axis]
                new_origin[axis] = src_origin[axis]

    output['size'] = new_size
    output['spacing'] = new_spacing
    output['origin'] = new_origin

    if linear:
        interpolator_cls = itk.LinearInterpolateImageFunction
    else:
        interpolator_cls = itk.NearestNeighborInterpolateImageFunction

    resampler = itk.ResampleImageFilter[imtype, imtype].New()
    resampler.SetInput(image)
    resampler.SetSize(tuple(int(new_size[k]) for k in range(3)))
    resampler.SetOutputSpacing(tuple(float(new_spacing[k]) for k in range(3)))
    resampler.SetOutputOrigin(tuple(float(new_origin[k]) for k in range(3)))
    resampler.SetTransform(transform)
    resampler.SetInterpolator(interpolator_cls[imtype, itk.D].New())
    # voxels falling outside the input take the input's minimum value
    resampler.SetDefaultPixelValue(int(np.min(itk.GetArrayFromImage(image))))
    resampler.Update()

    resampled = itk.GetArrayFromImage(resampler.GetOutput())
    # prepend a singleton axis and cast to the requested dtype
    output['array'] = resampled[np.newaxis, :].astype(dtype)
    return output
def make_onehot(input, cls):
    """Expand a label array into per-class background/foreground channels.

    :param input: label array whose first axis has length 1; class k is encoded
                  by the value k (1-based), everything else is background
    :param cls: number of classes
    :return: array of the same dtype with 2 * cls channels along axis 0;
             channel 2*i is the background mask and channel 2*i + 1 the
             foreground mask of class i + 1
    """
    onehot = np.repeat(np.zeros_like(input), 2 * cls, axis=0)
    for class_idx in range(cls):
        foreground = (input == class_idx + 1).astype(input.dtype)
        bg_channel = 2 * class_idx
        onehot[bg_channel, :] = 1 - foreground
        onehot[bg_channel + 1, :] = foreground
    return onehot
def make_flag(cls, labelmap):
    """Build a column vector marking which classes a dataset's label map covers.

    :param cls: number of classes
    :param labelmap: mapping from source label value to target class (1-based);
                     targets below 1 are treated as background and ignored
    :return: float array of shape (cls, 1); row k - 1 is 1 if some source label
             maps to class k, else 0
    :raises IndexError: if a target class exceeds cls
    """
    flag = np.zeros([cls, 1])
    for key in labelmap:
        target = labelmap[key]
        if target < 1:
            # a background target would otherwise index flag[-1] and
            # wrongly mark the last class as present
            continue
        flag[target - 1, 0] = 1
    return flag
def image2file(image, imtype, fname):
    """Write an ITK image to disk.

    :param image: ITK image to save
    :param imtype: ITK image type used to instantiate the writer
    :param fname: destination file name
    """
    file_writer = itk.ImageFileWriter[imtype].New()
    file_writer.SetFileName(fname)
    file_writer.SetInput(image)
    # Update() executes the writer
    file_writer.Update()
def array2file(array, size, origin, spacing, imtype, fname):
    """Wrap a numpy array as an ITK image with the given geometry and save it.

    :param array: voxel data; reshaped to (size[2], size[1], size[0])
    :param size: image size, indexed [0], [1], [2]
    :param origin: image origin (3 values)
    :param spacing: voxel spacing (3 values)
    :param imtype: ITK image type passed on to image2file()
    :param fname: destination file name
    """
    # the numpy view uses the reverse axis order of `size`
    reversed_shape = [size[2], size[1], size[0]]
    volume = itk.GetImageFromArray(array.reshape(reversed_shape))
    volume.SetSpacing(tuple(spacing[k] for k in range(3)))
    volume.SetOrigin(tuple(origin[k] for k in range(3)))
    image2file(volume, imtype=imtype, fname=fname)
# dataset of 3D image volume
# 3D volumes are resampled from and center-aligned with the original images
class Dataset(data.Dataset):
    """Whole-volume dataset: one resampled 3D volume per case.

    Each element of ``ids`` is
    [d_name, casename, image_fn, image_ld_fn, label_fn].
    """

    def __init__(self, ids, rs_size, rs_spacing, rs_intensity, label_map, cls_num):
        """Store the case list and the resampling configuration.

        :param ids: list of case entries (see class docstring)
        :param rs_size: output size passed to resample()
        :param rs_spacing: output spacing passed to resample()
        :param rs_intensity: (min, max) intensity window passed to normalize()
        :param label_map: dict d_name -> {source label: target class}
        :param cls_num: number of classes
        """
        self.ids = ids
        self.rs_size = rs_size
        self.rs_spacing = rs_spacing
        self.rs_intensity = rs_intensity
        self.label_map = label_map
        self.cls_num = cls_num
        self.ImageType = itk.Image[itk.SS, 3]
        self.LabelType = itk.Image[itk.UC, 3]

    def __len__(self):
        """Number of cases."""
        return len(self.ids)

    def __getitem__(self, index):
        """Read, resample and encode the case at *index*.

        :return: dict with the image tensor ('data'), the second image tensor
                 ('data_ld'), the one-hot label tensor ('label'), the class
                 presence flags ('label_exist'), identification fields and the
                 resampled/original geometry of the image. 'eof' is always True.
        """
        d_name, casename, image_fn, image_ld_fn, label_fn = self.ids[index]
        intensity_min, intensity_max = self.rs_intensity[0], self.rs_intensity[1]

        # no augmentation here: always resample through the identity transform
        t = generate_transform(identity=True)

        # primary image
        image = resample(
            image=read_image(fname=image_fn, imtype=self.ImageType),
            imtype=self.ImageType,
            size=self.rs_size, spacing=self.rs_spacing, origin=None,
            transform=t, linear=True, dtype=np.float32)
        image['array'] = normalize(image['array'], min=intensity_min, max=intensity_max)

        # second image (image_ld_fn), processed exactly like the first
        image_ld = resample(
            image=read_image(fname=image_ld_fn, imtype=self.ImageType),
            imtype=self.ImageType,
            size=self.rs_size, spacing=self.rs_spacing, origin=None,
            transform=t, linear=True, dtype=np.float32)
        image_ld['array'] = normalize(image_ld['array'], min=intensity_min, max=intensity_max)

        # label: nearest-neighbour interpolation keeps the values discrete
        label = resample(
            image=read_image(fname=label_fn, imtype=self.LabelType),
            imtype=self.LabelType,
            size=self.rs_size, spacing=self.rs_spacing, origin=None,
            transform=t, linear=False, dtype=np.int64)

        # translate dataset-specific label values into the common class ids;
        # values absent from the map become 0
        lmap = self.label_map[d_name]
        remapped = np.zeros_like(label['array'])
        for src_value in lmap:
            remapped[label['array'] == src_value] = lmap[src_value]
        label['array'] = remapped

        label_bin = make_onehot(label['array'], cls=self.cls_num)
        label_exist = make_flag(cls=self.cls_num, labelmap=self.label_map[d_name])

        return {
            'data': torch.from_numpy(image['array']),
            'data_ld': torch.from_numpy(image_ld['array']),
            'label': torch.from_numpy(label_bin),
            'label_exist': label_exist,
            'dataset': d_name,
            'case': casename,
            'size': image['size'],
            'spacing': image['spacing'],
            'origin': image['origin'],
            'org_size': image['org_size'],
            'org_spacing': image['org_spacing'],
            'org_origin': image['org_origin'],
            'eof': True,
        }
# dataset of image stacks (short-length 3D image volume)
# each image is resampled as a series adjacent image stacks
# the image stacks cover the whole range of image length
class DatasetStk(data.Dataset):
    """Dataset of image stacks (short 3D sub-volumes along the third axis).

    For every case the extent of the mapped foreground labels along the third
    image axis is measured, and one or more stacks of length
    ``rs_size[2] * rs_spacing[2]`` are laid out over that extent. Each stack
    becomes one dataset item.

    Each element of ``ids`` passed to the constructor is
    [d_name, casename, image_fn, image_ld_fn, label_fn].
    """

    def __init__(self, ids, rs_size, rs_spacing, rs_intensity, label_map, cls_num, perturb):
        """Scan all cases and precompute the geometry of every stack.

        :param ids: list of case entries (see class docstring)
        :param rs_size: stack size per axis (3 values)
        :param rs_spacing: stack spacing per axis (3 values)
        :param rs_intensity: (min, max) intensity window passed to normalize()
        :param label_map: dict d_name -> {source label: target class}
        :param cls_num: number of classes
        :param perturb: if true, __getitem__ draws the stack position along the
                        third axis at random inside the stack's perturbation range
        :raises ValueError: if a label volume has no voxel mapped to a class > 0
        """
        self.ImageType = itk.Image[itk.SS, 3]
        self.LabelType = itk.Image[itk.UC, 3]
        self.ids = []
        self.rs_size = rs_size
        self.rs_spacing = rs_spacing
        self.rs_intensity = rs_intensity
        self.label_map = label_map
        self.cls_num = cls_num
        self.perturb = perturb

        # physical length of one stack; identical for every case
        stack_len = rs_size[2] * rs_spacing[2]

        for i, [d_name, casename, image_fn, image_ld_fn, label_fn] in enumerate(ids):
            print('Preparing image stacks ({}/{}) ...'.format(i, len(ids)))

            # only the header of the image is needed here
            reader = sitk.ImageFileReader()
            reader.SetFileName(image_fn)
            reader.ReadImageInformation()
            size = reader.GetSize()
            spacing = reader.GetSpacing()
            origin = reader.GetOrigin()

            # the label voxels are needed to locate the foreground
            lb_reader = sitk.ImageFileReader()
            lb_reader.SetFileName(label_fn)
            lb_array = sitk.GetArrayFromImage(lb_reader.Execute())
            lmap = self.label_map[d_name]
            remapped = np.zeros_like(lb_array)
            for key in lmap:
                remapped[lb_array == key] = lmap[key]

            # axis 0 of the numpy array corresponds to image axis 2
            fg_slices = np.nonzero(remapped > 0)[0]
            if fg_slices.size == 0:
                raise ValueError(
                    'label of case {} ({}) contains no mapped foreground voxels'.format(casename, label_fn))
            first_slice = np.min(fg_slices)
            last_slice = np.max(fg_slices)
            lb_len = (last_slice - first_slice + 1) * spacing[2]
            lb_origin_z = origin[2] + first_slice * spacing[2]

            layout = self._stack_layout(lb_origin_z, lb_len, stack_len)
            stack_num = len(layout)
            for stack_id, (origin_z, perturb_lo, perturb_hi) in enumerate(layout):
                stack_size = np.array(rs_size, dtype=int)
                stack_spacing = np.array(rs_spacing, dtype=float)
                stack_origin = np.zeros(3, dtype=float)
                # centre the stack on the image in the first two axes
                stack_origin[0] = origin[0] + 0.5 * size[0] * spacing[0] - 0.5 * rs_size[0] * rs_spacing[0]
                stack_origin[1] = origin[1] + 0.5 * size[1] * spacing[1] - 0.5 * rs_size[1] * rs_spacing[1]
                stack_origin[2] = origin_z
                stack_perturb = np.array([perturb_lo, perturb_hi], dtype=float)
                self.ids.append([d_name, casename, image_fn, image_ld_fn, label_fn, stack_id, stack_size, stack_spacing, stack_origin, stack_perturb, stack_id == stack_num - 1])

    @staticmethod
    def _stack_layout(lb_origin_z, lb_len, stack_len):
        """Place stacks of length stack_len over a labelled extent.

        :param lb_origin_z: start of the labelled extent along the third axis
        :param lb_len: length of the labelled extent
        :param stack_len: length of one stack
        :return: list of (origin_z, perturb_lo, perturb_hi), one per stack, where
                 origin_z is the default stack origin and [perturb_lo, perturb_hi]
                 the range from which a perturbed origin may be drawn
        """
        if lb_len > stack_len:
            stack_num = int((lb_len + stack_len) / stack_len) + 1
        else:
            stack_num = 1

        half_stack = 0.5 * stack_len
        layout = []
        for stack_id in range(stack_num):
            if stack_num > 1:
                # spread the origins evenly from half a stack before the labelled
                # extent to half a stack before its end
                origin_z = lb_origin_z - half_stack + stack_id * lb_len / (stack_num - 1)
                perturb_lo = max(origin_z - half_stack, lb_origin_z - half_stack)
                perturb_hi = min(origin_z + half_stack, lb_origin_z + lb_len - half_stack)
            else:
                # a single stack is centred on the labelled extent
                origin_z = lb_origin_z + 0.5 * (lb_len - stack_len)
                perturb_lo = lb_origin_z - half_stack
                perturb_hi = lb_origin_z + lb_len - half_stack
            layout.append((origin_z, perturb_lo, perturb_hi))
        return layout

    def __len__(self):
        """Total number of stacks over all cases."""
        return len(self.ids)

    def __getitem__(self, index):
        """Read, resample and encode the stack at *index*.

        :return: dict with the image tensor ('data'), the second image tensor
                 ('data_ld'), the one-hot label tensor ('label'), the class
                 presence flags ('label_exist'), identification fields, the
                 resampled/original geometry of the image, and 'eof', which is
                 True for the last stack of a case.
        """
        [d_name, casename, image_fn, image_ld_fn, label_fn, _, stack_size, stack_spacing,
         base_origin, stack_perturb, eof] = self.ids[index]

        # work on a copy so the stored base origin is never modified
        stack_origin = base_origin.copy()
        if self.perturb:
            stack_origin[2] = stack_perturb[0] + random.random() * (stack_perturb[1] - stack_perturb[0])

        t = generate_transform(identity=True)

        src_image = read_image(fname=image_fn, imtype=self.ImageType)
        image = resample(
            image=src_image, imtype=self.ImageType,
            size=stack_size, spacing=stack_spacing, origin=stack_origin,
            transform=t, linear=True, dtype=np.float32)
        image['array'] = normalize(image['array'], min=self.rs_intensity[0], max=self.rs_intensity[1])

        # force the second image onto the geometry of the first one
        src_image_ld = read_image(fname=image_ld_fn, imtype=self.ImageType)
        src_image_ld.SetOrigin(src_image.GetOrigin())
        src_image_ld.SetSpacing(src_image.GetSpacing())
        src_image_ld.SetDirection(src_image.GetDirection())
        image_ld = resample(
            image=src_image_ld, imtype=self.ImageType,
            size=stack_size, spacing=stack_spacing, origin=stack_origin,
            transform=t, linear=True, dtype=np.float32)
        image_ld['array'] = normalize(image_ld['array'], min=self.rs_intensity[0], max=self.rs_intensity[1])

        # same for the label; nearest-neighbour keeps the values discrete
        src_label = read_image(fname=label_fn, imtype=self.LabelType)
        src_label.SetOrigin(src_image.GetOrigin())
        src_label.SetSpacing(src_image.GetSpacing())
        src_label.SetDirection(src_image.GetDirection())
        label = resample(
            image=src_label, imtype=self.LabelType,
            size=stack_size, spacing=stack_spacing, origin=stack_origin,
            transform=t, linear=False, dtype=np.int64)

        # translate dataset-specific label values into the common class ids;
        # values absent from the map become 0
        tmp_array = np.zeros_like(label['array'])
        lmap = self.label_map[d_name]
        for key in lmap:
            tmp_array[label['array'] == key] = lmap[key]
        label['array'] = tmp_array

        label_bin = make_onehot(label['array'], cls=self.cls_num)
        label_exist = make_flag(cls=self.cls_num, labelmap=self.label_map[d_name])

        image_tensor = torch.from_numpy(image['array'])
        image_ld_tensor = torch.from_numpy(image_ld['array'])
        label_tensor = torch.from_numpy(label_bin)

        output = {}
        output['data'] = image_tensor
        output['data_ld'] = image_ld_tensor
        output['label'] = label_tensor
        output['label_exist'] = label_exist
        output['dataset'] = d_name
        output['case'] = casename
        output['size'] = image['size']
        output['spacing'] = image['spacing']
        output['origin'] = image['origin']
        output['org_size'] = image['org_size']
        output['org_spacing'] = image['org_spacing']
        output['org_origin'] = image['org_origin']
        output['eof'] = eof
        return output