RuntimeError: The size of tensor a (6) must match the size of tensor b (96) at non-singleton dimension 0

So the error is occurring on line 678 of your code and it’s most likely because out and x16 aren’t broadcastable. Could you print the size of these Tensors?

Also, you can copy code in by wrapping it in three backticks ```
For example,

x = torch.randn(10,10)

thanks for ur quick response, I could not find the size of these tensors but here are the model.py, train.py, and dataset.py



# model.py

import torch
import torch.nn as nn
#  import torch.optim as optim
import torch.nn.functional as F
#  import argparse
#  import torch.utils.data.sampler as sampler

#  from torch.autograd import Variable

#  import numpy as np
#  from torch.autograd import Variable


class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2'''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        y = self.conv(x)
        return y

# Encoding block in U-Net


class enc_block(nn.Module):
    def __init__(self, in_ch, out_ch, dropout_prob=0.0):
        super(enc_block, self).__init__()
        self.conv = double_conv(in_ch, out_ch)
        self.down = nn.MaxPool3d(2)
        self.dropout_prob = dropout_prob
        if dropout_prob > 0:
            self.dropout = nn.Dropout3d(p=dropout_prob)

    def forward(self, x):
        y_conv = self.conv(x)
        y = self.down(y_conv)
        if self.dropout_prob > 0:
            y = self.dropout(y)
        return y, y_conv

# Decoding block in U-Net


class dec_block(nn.Module):
    def __init__(self, in_ch, out_ch, bilinear=True, dropout_prob=0.0):
        super(dec_block, self).__init__()
        self.conv = double_conv(in_ch, out_ch)
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        else:
            self.up = nn.ConvTranspose3d(out_ch, out_ch, 2, stride=2)
        self.dropout_prob = dropout_prob
        if dropout_prob > 0:
            self.dropout = nn.Dropout3d(p=dropout_prob)

    def forward(self, x):
        y_conv = self.conv(x)
        y = self.up(y_conv)
        if self.dropout_prob > 0:
            y = self.dropout(y)
        return y, y_conv


def concatenate(x1, x2):
    # input is CHW
    diffY = x2.size()[2] - x1.size()[2]
    diffX = x2.size()[3] - x1.size()[3]
    x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
                    diffY // 2, diffY - diffY//2))        
    y = torch.cat([x2, x1], dim=1)
    return y


class softmax(nn.Module):
    def __init__(self, cls_num):
        super(softmax, self).__init__()
        self.softmax = nn.Softmax(dim=1)
        self.cls_num = cls_num

    def forward(self, x):
        y = torch.zeros_like(x)
        for i in range(self.cls_num):
            y[:,i*2:i*2+2] = self.softmax(x[:,i*2:i*2+2])
        return y

class fuse_conv(nn.Module):
    def __init__(self, ch):
        super(fuse_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(ch, ch, 1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        y = self.conv(x)
        return y

# U-Net (baseline)


class UNet(nn.Module):
    def __init__(self, in_ch, cls_num, base_ch=64):
        super(UNet, self).__init__()
        self.in_ch = in_ch
        self.cls_num = cls_num
        self.base_ch = base_ch

        self.enc1 = enc_block(in_ch, base_ch)
        self.enc2 = enc_block(base_ch, base_ch*2)
        self.enc3 = enc_block(base_ch*2, base_ch*4)
        self.enc4 = enc_block(base_ch*4, base_ch*8)

        self.dec1 = dec_block(base_ch*8, base_ch*8, bilinear=False)
        self.dec2 = dec_block(base_ch*8+base_ch*8, base_ch*4, bilinear=False)
        self.dec3 = dec_block(base_ch*4+base_ch*4, base_ch*2, bilinear=False)
        self.dec4 = dec_block(base_ch*2+base_ch*2, base_ch, bilinear=False)
        self.lastconv = double_conv(base_ch+base_ch, base_ch)

        self.outconv = nn.Conv3d(base_ch, cls_num*2, 1)
        self.softmax = softmax(cls_num)

    def forward(self, x):
        enc1, enc1_conv = self.enc1(x)
        enc2, enc2_conv = self.enc2(enc1)
        enc3, enc3_conv = self.enc3(enc2)
        enc4, enc4_conv = self.enc4(enc3)
        dec1, _ = self.dec1(enc4)
        dec2, _ = self.dec2(concatenate(dec1, enc4_conv))
        dec3, _ = self.dec3(concatenate(dec2, enc3_conv))
        dec4, _ = self.dec4(concatenate(dec3, enc2_conv))
        lastconv = self.lastconv(concatenate(dec4, enc1_conv))
        output = self.outconv(lastconv)
        output = self.softmax(output)

        return output

    def description(self):
        return 'U-Net with {0:d}-ch input for {1:d}-class segmentation (base channel = {2:d})'.format(self.in_ch, self.cls_num, self.base_ch)

# U-Net with three decoding paths sharing one encoding path
# each decoding path predicts a binary mask of one target organ
class UNet2(nn.Module):
    def __init__(self, in_ch, cls_num, base_ch=64):
        super(UNet2, self).__init__()
        self.in_ch = in_ch
        self.cls_num = cls_num
        self.base_ch = base_ch

        self.enc1 = enc_block(in_ch, base_ch)
        self.enc2 = enc_block(base_ch, base_ch*2)
        self.enc3 = enc_block(base_ch*2, base_ch*4)
        self.enc4 = enc_block(base_ch*4, base_ch*8)

        self.dec11 = dec_block(base_ch*8, base_ch*4, bilinear=False)
        self.dec12 = dec_block(base_ch*4+base_ch*8, base_ch*2, bilinear=False)
        self.dec13 = dec_block(base_ch*2+base_ch*4, base_ch, bilinear=False)
        self.dec14 = dec_block(base_ch+base_ch*2, base_ch//2, bilinear=False)
        self.lastconv1 = double_conv(base_ch//2+base_ch, base_ch//2)
        self.outconv1 = nn.Conv3d(base_ch//2, 2, 1)
        self.softmax1 = nn.Softmax(dim=1)

        self.dec21 = dec_block(base_ch*8, base_ch*4, bilinear=False)
        self.dec22 = dec_block(base_ch*4+base_ch*8, base_ch*2, bilinear=False)
        self.dec23 = dec_block(base_ch*2+base_ch*4, base_ch, bilinear=False)
        self.dec24 = dec_block(base_ch+base_ch*2, base_ch//2, bilinear=False)
        self.lastconv2 = double_conv(base_ch//2+base_ch, base_ch//2)
        self.outconv2 = nn.Conv3d(base_ch//2, 2, 1)
        self.softmax2 = nn.Softmax(dim=1)

        self.dec31 = dec_block(base_ch*8, base_ch*4, bilinear=False)
        self.dec32 = dec_block(base_ch*4+base_ch*8, base_ch*2, bilinear=False)
        self.dec33 = dec_block(base_ch*2+base_ch*4, base_ch, bilinear=False)
        self.dec34 = dec_block(base_ch+base_ch*2, base_ch//2, bilinear=False)
        self.lastconv3 = double_conv(base_ch//2+base_ch, base_ch//2)
        self.outconv3 = nn.Conv3d(base_ch//2, 2, 1)
        self.softmax3 = nn.Softmax(dim=1)

    def forward(self, x):
        enc1, enc1_conv = self.enc1(x)
        enc2, enc2_conv = self.enc2(enc1)
        enc3, enc3_conv = self.enc3(enc2)
        enc4, enc4_conv = self.enc4(enc3)

        dec11, _ = self.dec11(enc4)
        dec12, _ = self.dec12(concatenate(dec11, enc4_conv))
        dec13, _ = self.dec13(concatenate(dec12, enc3_conv))
        dec14, _ = self.dec14(concatenate(dec13, enc2_conv))
        lastconv1 = self.lastconv1(concatenate(dec14, enc1_conv))
        output1 = self.outconv1(lastconv1)
        output1 = self.softmax1(output1)

        dec21, _ = self.dec21(enc4)
        dec22, _ = self.dec22(concatenate(dec21, enc4_conv))
        dec23, _ = self.dec23(concatenate(dec22, enc3_conv))
        dec24, _ = self.dec24(concatenate(dec23, enc2_conv))
        lastconv2 = self.lastconv2(concatenate(dec24, enc1_conv))
        output2 = self.outconv2(lastconv2)
        output2 = self.softmax2(output2)

        dec31, _ = self.dec31(enc4)
        dec32, _ = self.dec32(concatenate(dec31, enc4_conv))
        dec33, _ = self.dec33(concatenate(dec32, enc3_conv))
        dec34, _ = self.dec34(concatenate(dec33, enc2_conv))
        lastconv3 = self.lastconv3(concatenate(dec34, enc1_conv))
        output3 = self.outconv3(lastconv3)
        output3 = self.softmax3(output3)

        output = torch.cat([output1, output2, output3], dim=1)

        return output

    def description(self):
        return 'U-Net with three decoding paths sharing one encoding path [{0:d}-ch input][{1:d}-class seg][base ch={2:d}]'.format(self.in_ch, self.cls_num, self.base_ch)


# model 'UNet2' with intermediate 1x1 conv layers re-weighting features maps from different organ branches (decoding path)


class UNet2_ch1x1(nn.Module):
    def __init__(self, in_ch, cls_num, base_ch=64):
        super(UNet2_ch1x1, self).__init__()
        self.in_ch = in_ch
        self.cls_num = cls_num
        self.base_ch = base_ch

        self.enc1 = enc_block(in_ch, base_ch)
        self.enc2 = enc_block(base_ch, base_ch*2)
        self.enc3 = enc_block(base_ch*2, base_ch*4)
        self.enc4 = enc_block(base_ch*4, base_ch*8)

        self.dec11 = dec_block(base_ch*8, base_ch*4, bilinear=False)
        self.dec12 = dec_block(base_ch*4+base_ch*8, base_ch*2, bilinear=False)
        self.dec13 = dec_block(base_ch*2+base_ch*4, base_ch, bilinear=False)
        self.dec14 = dec_block(base_ch+base_ch*2, base_ch//2, bilinear=False)

        self.dec21 = dec_block(base_ch*8, base_ch*4, bilinear=False)
        self.dec22 = dec_block(base_ch*4+base_ch*8, base_ch*2, bilinear=False)
        self.dec23 = dec_block(base_ch*2+base_ch*4, base_ch, bilinear=False)
        self.dec24 = dec_block(base_ch+base_ch*2, base_ch//2, bilinear=False)

        self.dec31 = dec_block(base_ch*8, base_ch*4, bilinear=False)
        self.dec32 = dec_block(base_ch*4+base_ch*8, base_ch*2, bilinear=False)
        self.dec33 = dec_block(base_ch*2+base_ch*4, base_ch, bilinear=False)
        self.dec34 = dec_block(base_ch+base_ch*2, base_ch//2, bilinear=False)

        self.fuse1 = fuse_conv(base_ch*4*3)
        self.fuse2 = fuse_conv(base_ch*2*3)
        self.fuse3 = fuse_conv(base_ch*3)
        self.fuse4 = fuse_conv((base_ch//2)*3)

        self.lastconv1 = double_conv(base_ch//2+base_ch, base_ch//2)
        self.outconv1 = nn.Conv3d(base_ch//2, 2, 1)
        self.softmax1 = nn.Softmax(dim=1)

        self.lastconv2 = double_conv(base_ch//2+base_ch, base_ch//2)
        self.outconv2 = nn.Conv3d(base_ch//2, 2, 1)
        self.softmax2 = nn.Softmax(dim=1)

        self.lastconv3 = double_conv(base_ch//2+base_ch, base_ch//2)
        self.outconv3 = nn.Conv3d(base_ch//2, 2, 1)
        self.softmax3 = nn.Softmax(dim=1)

    def forward(self, x):
        enc1, enc1_conv = self.enc1(x)
        enc2, enc2_conv = self.enc2(enc1)
        enc3, enc3_conv = self.enc3(enc2)
        enc4, enc4_conv = self.enc4(enc3)

        dec11, _ = self.dec11(enc4)
        dec21, _ = self.dec21(enc4)
        dec31, _ = self.dec31(enc4)

        f1 = self.fuse1(torch.cat([dec11, dec21, dec31], dim=1))
        [dec11_, dec21_, dec31_] = torch.split(f1, self.base_ch*4, dim=1)

        dec12, _ = self.dec12(concatenate(dec11_, enc4_conv))
        dec22, _ = self.dec22(concatenate(dec21_, enc4_conv))
        dec32, _ = self.dec32(concatenate(dec31_, enc4_conv))

        f2 = self.fuse2(torch.cat([dec12, dec22, dec32], dim=1))
        [dec12_, dec22_, dec32_] = torch.split(f2, self.base_ch*2, dim=1)

        dec13, _ = self.dec13(concatenate(dec12_, enc3_conv))
        dec23, _ = self.dec23(concatenate(dec22_, enc3_conv))
        dec33, _ = self.dec33(concatenate(dec32_, enc3_conv))

        f3 = self.fuse3(torch.cat([dec13, dec23, dec33], dim=1))
        [dec13_, dec23_, dec33_] = torch.split(f3, self.base_ch, dim=1)
        
        dec14, _ = self.dec14(concatenate(dec13_, enc2_conv))
        dec24, _ = self.dec24(concatenate(dec23_, enc2_conv))
        dec34, _ = self.dec34(concatenate(dec33_, enc2_conv))

        f4 = self.fuse4(torch.cat([dec14, dec24, dec34], dim=1))
        [dec14_, dec24_, dec34_] = torch.split(f4, self.base_ch//2, dim=1)
        
        lastconv1 = self.lastconv1(concatenate(dec14_, enc1_conv))
        output1 = self.outconv1(lastconv1)
        output1 = self.softmax1(output1)

        lastconv2 = self.lastconv2(concatenate(dec24_, enc1_conv))
        output2 = self.outconv2(lastconv2)
        output2 = self.softmax2(output2)

        lastconv3 = self.lastconv3(concatenate(dec34_, enc1_conv))
        output3 = self.outconv3(lastconv3)
        output3 = self.softmax3(output3)

        output = torch.cat([output1, output2, output3], dim=1)

        return output

    def description(self):
        return '<UNet2> with intermediate 1x1 conv layers re-weighting features maps from different organ branches (decoding path) [{0:d}-ch input][{1:d}-class seg][base ch={2:d}]'.format(self.in_ch, self.cls_num, self.base_ch)


#################################### TOD Net  #####################################
# modified from https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
class ResnetBlock(nn.Module):
    def __init__(self, dim, padding_type='reflect', norm_layer=nn.BatchNorm3d, use_dropout=False, use_bias=False):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad3d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad3d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv3d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim),
                       nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad3d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad3d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv3d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
                       norm_layer(dim)]
        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out

'''WGAN discriminator'''
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.out_shape = 32
        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(32),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv3d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(64),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm3d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.AvgPool3d(2)
        )
        ############################################################
        self.linear = nn.Sequential(
            torch.nn.Linear(128 * 4 * self.out_shape ** 2, 1),
        )


    def forward(self, x):
        # Convolutional layers
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(-1, 128 * 4 * self.out_shape ** 2)
        x = self.linear(x)
        return x



#################################### WGAN  #####################################
'''WGAN generator'''
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv3d(1, 32, kernel_size=3, padding=1),
            nn.InstanceNorm3d(32),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # torch.nn.init.
        self.conv2 = nn.Sequential(
            nn.Conv3d(32, 64, kernel_size=3, padding=1),
            nn.InstanceNorm3d(64),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv3d(64, 128, kernel_size=3, padding=1),
            nn.InstanceNorm3d(128),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv4 = nn.Sequential(
            nn.Conv3d(128, 256, kernel_size=3, padding=1),
            nn.InstanceNorm3d(256),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.deConv1_1 = nn.Conv3d(256, 128, kernel_size=3, padding=1)
        self.deConv1 = nn.Sequential(
            nn.InstanceNorm3d(128),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.deConv2_1 = nn.Conv3d(128, 64, kernel_size=3, padding=1)
        self.deConv2 = nn.Sequential(
            nn.InstanceNorm3d(64),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.deConv3_1 = nn.Conv3d(64, 32, kernel_size=3, padding=1)
        self.deConv3 = nn.Sequential(
            nn.InstanceNorm3d(32),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.deConv4_1 = nn.Conv3d(32, 1, kernel_size=3, padding=1)
        self.deConv4 = nn.Tanh() #nn.ReLU()

    def forward(self, input):
        conv1 = self.conv1(input)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        
        x = self.deConv1_1(conv4)
        x = x + conv3

        deConv1 = self.deConv1(x)
        
        x = self.deConv2_1(deConv1)
        x += conv2
        deConv2 = self.deConv2(x)

        x = self.deConv3_1(deConv2)
        x += conv1
        deConv3 = self.deConv3(x)

        x = self.deConv4_1(deConv3)
        x += input
        output = self.deConv4(x)

        return output



'''WGAN discriminator'''

'''
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.out_shape = 32
        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(32),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv3d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm3d(64),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
            nn.BatchNorm3d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.AvgPool3d(2)
        )
        ############################################################
        self.linear = nn.Sequential(
            torch.nn.Linear(128 * 4 * self.out_shape ** 2, 1),
#             nn.LeakyReLU(inplace=True)
        )

    def forward(self, x):
        # Convolutional layers
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        # Flatten and apply sigmoid
        x = x.view(-1, 128 * 4 * self.out_shape ** 2)
        x = self.linear(x)
        return x


   ''' 
    
#################################### VoxResNet  #####################################
import torch
import torch.nn as nn

class softmax(nn.Module):
    def __init__(self, cls_num):
        super(softmax, self).__init__()
        self.softmax = nn.Softmax(dim=1)
        self.cls_num = cls_num

    def forward(self, x):
        y = torch.zeros_like(x)
        for i in range(self.cls_num):
            y[:,i*2:i*2+2] = self.softmax(x[:,i*2:i*2+2])
        return y

class VoxRes(nn.Module):
    def __init__(self, in_channel):
        super(VoxRes, self).__init__()
        self.block = nn.Sequential(
            nn.BatchNorm3d(in_channel), 
            nn.ReLU(),
            nn.Conv3d(in_channel, in_channel, kernel_size=3, padding=1),
            nn.BatchNorm3d(in_channel), 
            nn.ReLU(),
            nn.Conv3d(in_channel, in_channel, kernel_size=3, padding=1)
            )

    def forward(self, x):
        return self.block(x)+x
    
    
class VoxResNet(nn.Module):
    def __init__(self, in_channels, num_class):
        super(VoxResNet, self).__init__()
        
        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm3d(32), 
            nn.ReLU(),
            nn.Conv3d(32, 32, kernel_size=3, padding=1)
            )
        
        self.conv2 = nn.Sequential(
            nn.BatchNorm3d(32), 
            nn.ReLU(),
            nn.Conv3d(32, 64, kernel_size=3, stride=(1, 2, 2), padding=1),
            VoxRes(64),
            VoxRes(64)
            )

        self.conv3 = nn.Sequential(
            nn.BatchNorm3d(64), 
            nn.ReLU(),
            nn.Conv3d(64, 64, kernel_size=3, stride=(1, 2, 2), padding=1),
            VoxRes(64),
            VoxRes(64)
            )
        
        self.conv4 = nn.Sequential(
            nn.BatchNorm3d(64), 
            nn.ReLU(),
            nn.Conv3d(64, 64, kernel_size=3, stride=(1, 2, 2), padding=1),
            VoxRes(64),
            VoxRes(64)
            )
        
        self.deconv_c1 = nn.Sequential(
            nn.ConvTranspose3d(32, 32, kernel_size=(1, 1, 1), stride=(1, 1, 1)),
            nn.Conv3d(32, num_class*2, kernel_size=1))
        
        self.deconv_c2 = nn.Sequential(
            nn.ConvTranspose3d(64, 64, kernel_size=(1, 2, 2), stride=(1, 2, 2)),
            nn.Conv3d(64, num_class*2, kernel_size=1))
        
        self.deconv_c3 = nn.Sequential(
            nn.ConvTranspose3d(64, 64, kernel_size=(1, 4, 4), stride=(1, 4, 4)),
            nn.Conv3d(64, num_class*2, kernel_size=1))
        
        self.deconv_c4 = nn.Sequential(
            nn.ConvTranspose3d(64, 64, kernel_size=(1, 8, 8), stride=(1, 8, 8)),
            nn.Conv3d(64, num_class*2, kernel_size=1))
        
#         self.outconv = nn.Conv3d(num_class, num_class, 1)
        self.softmax = softmax(num_class)

    def forward(self, x):
        out1 = self.conv1(x)
        out2 = self.conv2(out1)
        out3 = self.conv3(out2)     
        out4 = self.conv4(out3)
        
        c1 = self.deconv_c1(out1)
        c2 = self.deconv_c2(out2)
        c3 = self.deconv_c3(out3)
        c4 = self.deconv_c4(out4)
        
        return self.softmax(c1+c2+c3+c4)

# missing classes I have added

def ELUCons(elu, nchan):
    if elu:
        return nn.ELU(inplace=True)
    else:
        return nn.PReLU(nchan)

def passthrough(x, **kwargs):
    return x


def _make_nConv(nchan, depth, elu):
    layers = []
    for _ in range(depth):
        layers.append(LUConv(nchan, elu))
    return nn.Sequential(*layers)



class ContBatchNorm3d(nn.BatchNorm3d):
    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(ContBatchNorm3d, self)._check_input_dim(input)

    def forward(self, input):
        self._check_input_dim(input)
        return F.batch_norm(
            input, self.running_mean, self.running_var, self.weight, self.bias,
            True, self.momentum, self.eps)
class LUConv(nn.Module):
    def __init__(self, nchan, elu):
        super(LUConv, self).__init__()
        self.relu1 = ELUCons(elu, nchan)
        self.conv1 = nn.Conv3d(nchan, nchan, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(nchan)

    def forward(self, x):
        out = self.relu1(self.bn1(self.conv1(x)))
        return out
class InputTransition(nn.Module):
    def __init__(self, outChans, elu):
        super(InputTransition, self).__init__()
        self.conv1 = nn.Conv3d(1, 16, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(16)
        self.relu1 = ELUCons(elu, 16) 
    def forward(self, x):
        # do we want a PRELU here as well?
        out = self.bn1(self.conv1(x))
        # split input in to 16 channels
        x16 = torch.cat((x, x, x, x, x, x, x, x,
                         x, x, x, x, x, x, x, x), 0)
        out = self.relu1(torch.add(out, x16))
        return out
  

class DownTransition(nn.Module):
    def __init__(self, inChans, nConvs, elu, dropout=False):
        super(DownTransition, self).__init__()
        outChans = 2*inChans
        self.down_conv = nn.Conv3d(inChans, outChans, kernel_size=2, stride=2)
        self.bn1 = ContBatchNorm3d(outChans)
        self.do1 = passthrough
        self.relu1 = ELUCons(elu, outChans)
        self.relu2 = ELUCons(elu, outChans)
        if dropout:
            self.do1 = nn.Dropout3d()
        self.ops = _make_nConv(outChans, nConvs, elu)

    def forward(self, x):
        down = self.relu1(self.bn1(self.down_conv(x)))
        out = self.do1(down)
        out = self.ops(out)
        out = self.relu2(torch.add(out, down))
        return out

class UpTransition(nn.Module):
    def __init__(self, inChans, outChans, nConvs, elu, dropout=False):
        super(UpTransition, self).__init__()
        self.up_conv = nn.ConvTranspose3d(inChans, outChans // 2, kernel_size=2, stride=2)
        self.bn1 = ContBatchNorm3d(outChans // 2)
        self.do1 = passthrough
        self.do2 = nn.Dropout3d()
        self.relu1 = ELUCons(elu, outChans // 2)
        self.relu2 = ELUCons(elu, outChans)
        if dropout:
            self.do1 = nn.Dropout3d()
        self.ops = _make_nConv(outChans, nConvs, elu)

    def forward(self, x, skipx):
        out = self.do1(x)
        skipxdo = self.do2(skipx)
        out = self.relu1(self.bn1(self.up_conv(out)))
        xcat = torch.cat((out, skipxdo), 1)
        out = self.ops(xcat)
        out = self.relu2(torch.add(out, xcat))
        return out

class OutputTransition(nn.Module):
    def __init__(self, inChans, elu, nll):
        super(OutputTransition, self).__init__()
        self.conv1 = nn.Conv3d(inChans, 2, kernel_size=5, padding=2)
        self.bn1 = ContBatchNorm3d(2)
        self.conv2 = nn.Conv3d(2, 2, kernel_size=1)
        self.relu1 = ELUCons(elu, 2)
        if nll:
            self.softmax = F.log_softmax
        else:
            self.softmax = F.softmax

    def forward(self, x):
        # convolve 32 down to 2 channels
        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.conv2(out)

        # make channels the last axis
        out = out.permute(0, 2, 3, 4, 1).contiguous()
        # flatten
        out = out.view(out.numel() // 2, 2)
        out = self.softmax(out)
        # treat channel 0 as the predicted output
        return out
    
#################################### VNet  #####################################
class VNet(nn.Module):
    # the number of convolutions in each layer corresponds
    # to what is in the actual prototxt, not the intent
    def __init__(self, elu=True, nll=False):
        super(VNet, self).__init__()
        self.in_tr = InputTransition(16, elu)
        self.down_tr32 = DownTransition(16, 1, elu)
        self.down_tr64 = DownTransition(32, 2, elu)
        self.down_tr128 = DownTransition(64, 3, elu, dropout=True)
        self.down_tr256 = DownTransition(128, 2, elu, dropout=True)
        self.up_tr256 = UpTransition(256, 256, 2, elu, dropout=True)
        self.up_tr128 = UpTransition(256, 128, 2, elu, dropout=True)
        self.up_tr64 = UpTransition(128, 64, 1, elu)
        self.up_tr32 = UpTransition(64, 32, 1, elu)
        self.out_tr = OutputTransition(32, elu, nll)
        

    def forward(self, x):
        out16 = self.in_tr(x)
        out32 = self.down_tr32(out16)
        out64 = self.down_tr64(out32)
        out128 = self.down_tr128(out64)
        out256 = self.down_tr256(out128)
        out = self.up_tr256(out256, out128)
        out = self.up_tr128(out, out64)
        out = self.up_tr64(out, out32)
        out = self.up_tr32(out, out16)
        out = self.out_tr(out)
        return out


    






train.py

# !nvidia-smi

# import sys
import os
import numpy as np
# import random
import time
# os.environ['CUDA_VISIBLE_DEVICES'] = "3"
import torch

import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils import data
# torch.cuda.current_device()


# from dataset import *
from dataset import create_folds, DatasetStk
from model import VNet 
# VoxRes, VoxResNet, UNet, UNet2, UNet2_ch1x1, ResUNet, ResnetBlock
from loss import DiceLoss

from utils import init_model, save_model
from metric import eval
# import math

# torch.cuda.empty_cache()     


class WGAN_De_Att:
    def __init__(self, netV=None):
        # initialize networks
        if netV is None:
            self.netV = init_model(VNet())
        else:
            self.netV = netV
        # initialize optimizers
        self.optimizer_V = optim.SGD(self.netV.parameters(), lr=1e-2, momentum=0.99, weight_decay=1e-8)
    
    def scheduler(self, epoch, lr0=0.01):
        if epoch>=80:
            lr = lr0*0.1
            for param_group in self.optimizer_V.param_groups:
                param_group['lr'] = lr
        if epoch>100:
            lr = lr0*0.01    
            for param_group in self.optimizer_V.param_groups:
                param_group['lr'] = lr
    
    def batch_train(self, cfg, dl_train, epoch_id):
        self.netV.train()
        criterion = DiceLoss()

        # for this epoch
        epoch_loss = np.zeros(cfg['cls_num'], dtype=np.float)
        epoch_loss_num = np.zeros(cfg['cls_num'], dtype=np.int64)
        
        for batch_id, batch in enumerate(dl_train):
            image = batch['data']
            label = batch['label']
            flag = batch['label_exist']
            n = len(image)

            image, label = image.cuda(), label.cuda()
            
            print_line = 'Epoch {0:d}/{1:d} (train) --- Progress {2:5.2f}% (+{3:d})'.format(
                    epoch_id+1, cfg['epoch_num'], 100.0 * batch_id * cfg['batch_size'] / len(d_train), n)
#             occupy_mem_allgpus()
            ### UNet loss
            '''---Dice loss---'''
            loss = 0
            cls_loss = np.zeros(cfg['cls_num'], dtype=np.float)
            pred = self.netV(image)
            for c in range(cfg['cls_num']):
                if torch.sum(flag[:,c]) > 0:
                    l = criterion(pred[:,c*2:c*2+2], label[:,c*2:c*2+2], flag[:,c])
                    loss += l
                    cls_loss[c] = l.item()
                    epoch_loss[c] += cls_loss[c]
                    epoch_loss_num[c] += 1
                else:
                    cls_loss[c] = 0

            print_line += ' -- Dice Loss: {0:.4f}'.format(loss.item())
            print(print_line)
            
            self.netV.zero_grad()
            loss.backward()
            self.optimizer_V.step()
            del image, label, pred, loss

        train_loss = np.sum(epoch_loss)/np.sum(epoch_loss_num)
        epoch_loss = epoch_loss / epoch_loss_num
        print_line = 'Epoch {0:d}/{1:d} (train) --- Loss: {2:.6f} ({3:s})\n'.format(
                epoch_id+1, cfg['epoch_num'], train_loss, '/'.join(['%.6f']*len(epoch_loss)) % tuple(epoch_loss))
        print(print_line)
        #torch.cuda.empty_cache() 
      
    def batch_eval(self, cfg, dl_val, epoch_id, loss_fn, eval_test=False):
        self.netV.eval()
        criterion = DiceLoss()
        output_mask = []

        for c in range(cfg['cls_num']):
            output_mask.append(None)
        for batch_id, batch in enumerate(dl_val):
            image = batch['data']
            label = batch['label']
            flag = batch['label_exist']
            n = len(image)
            image = image.cuda()
            label = label.cuda()
            pred = self.netV(image)

#             print('------------------\n')
#             for ii in range(n):
#                 print(torch.max(label[ii,1,:]), torch.min(label[ii,1,:]))

            print_line = 'Epoch {0:d}/{1:d} (val) --- Progress {2:5.2f}% (+{3:d})'.format(
                epoch_id+1, cfg['epoch_num'], 100.0 * batch_id * cfg['batch_size'] / len(d_val), n)
            print(print_line)
#             occupy_mem_allgpus()
            
            for c in range(cfg['cls_num']):
                pred_bin = torch.argmax(pred[:,c*2:c*2+2], dim=1, keepdim=True)
                for i in range(n):
                    if flag[i, c] > 0:
                        mask = pred_bin[i,:].contiguous().cpu().numpy().copy().astype(dtype=np.uint8)
                        mask = np.squeeze(mask)
#                         print('-----------\n')
#                         print(np.max(mask), np.min(mask))
                        mask = resample_array(
                                mask, batch['size'][i].numpy(), batch['spacing'][i].numpy(), batch['origin'][i].numpy(), 
                                batch['org_size'][i].numpy(), batch['org_spacing'][i].numpy(), batch['org_origin'][i].numpy())
                        
                        if output_mask[c] is None:
                            output_mask[c] = mask
                        else:
                            output_mask[c] = output_mask[c] + mask

                        if batch['eof'][i]:
                            output_mask[c][output_mask[c] > 0] = 1
                            output2file(
                                output_mask[c], batch['org_size'][i].numpy(), batch['org_spacing'][i].numpy(), batch['org_origin'][i].numpy(), 
                                '{}/{}@{}@{}.nii.gz'.format(val_result_path, batch['dataset'][i], batch['case'][i], c+1))
                            output_mask[c] = None
                del pred_bin
            del image, pred
        
        if eval_test is True:
            dsc, asd, dsc_m, asd_m = eval(
                pd_path=val_result_path, gt_entries=test_fold, label_map=cfg['label_map'], cls_num=cfg['cls_num'], 
                metric_fn='metric_{0:04d}'.format(epoch_id), calc_asd=False)
        else:
            dsc, asd, dsc_m, asd_m = eval(
                pd_path=val_result_path, gt_entries=val_fold, label_map=cfg['label_map'], cls_num=cfg['cls_num'], 
                metric_fn='metric_{0:04d}'.format(epoch_id), calc_asd=False)

        print_line = 'Epoch {0:d}/{1:d} (val) --- DSC {2:.2f} ({3:s})% --- ASD {4:.2f} ({5:s})mm'.format(
            epoch_id+1, cfg['epoch_num'], 
            dsc_m*100.0, '/'.join(['%.2f']*len(dsc[:,0])) % tuple(dsc[:,0]*100.0), 
            asd_m, '/'.join(['%.2f']*len(asd[:,0])) % tuple(asd[:,0]))
        print(print_line)  
        
    def Net_train(self, cfg, dl_train, dl_val, loss_fn, save_dir, if_eval=False):
        #    check
        #    self.batch_eval(cfg, dl_val, 0, loss_fn)
        
        for epoch_id in range(cfg['epoch_num']):
            t0 = time.perf_counter()
            self.scheduler(epoch_id)
            # train
            self.batch_train(cfg, dl_train, epoch_id)
            # evaluation
            if if_eval is True: 
                self.batch_eval(cfg, dl_val, epoch_id, loss_fn)
            # save model
            file_name_V = 'V_epoch{}'.format(epoch_id)+'.pth'
            save_model(self.netV, save_dir, file_name_V)
            # eval
            if (epoch_id+1)%1==0:
                self.batch_eval(cfg, dl_val, epoch_id, loss_fn)

if __name__ == '__main__':
    cfg = {}
    cfg['cls_num'] = 1
    cfg['gpu'] = '3'   # to use multiple gpu: cfg['gpu'] = '0,1,2,3'
    cfg['fold_num'] = 5
    cfg['epoch_num'] = 120
    cfg['batch_size'] = 6
    cfg['lr'] = 0.001
    cfg['Unet_path'] = ' '
    cfg['model_path'] = ' '
    cfg['rs_size'] = [256,256,32]   # resample size: [x, y, z]
    cfg['rs_spacing'] = [1.5,1.5,3.0]   # resample spacing: [x, y, z]. non-positive value means adaptive spacing fit the physical size: rs_size * rs_spacing = origin_size * origin_spacing
    cfg['rs_intensity'] = [-200.0, 200.0]   # rescale intensity from [min, max] to [0, 1].
    cfg['cpu_thread'] = 8  # multi-thread for data loading. zero means single thread.

    # list of dataset names and paths
    cfg['data_path_train'] = [
        [
        # ['KiTS', '/home/41/hh19/data_new2']
    ]
    cfg['data_path_test'] = [        
        ['BTCV', ' ']
    ]
    cfg['label_map'] = {
        'KiTS':{1:1}
        #'BTCV':{2:1}
    }

    # exclude any samples in the form of '[dataset_name, case_name]'
    cfg['exclude_case'] = [
        ['KiTS', 'case_00133'],
        ['KiTS', 'case_00134'],
        ['KiTS', 'case_00135'],
        ['KiTS', 'case_00136'],
        ['KiTS', 'case_00137'],
        ['KiTS', 'case_00138'],
        ['KiTS', 'case_00139'],
        ['KiTS', 'case_00140'],
        ['KiTS', 'case_00141'],
        ['KiTS', 'case_00142'],
        # ['LiTS', 'volume-102']
    ]

    os.environ["CUDA_VISIBLE_DEVICES"] = cfg['gpu']


train_start_time = time.localtime()
time_stamp = time.strftime("%Y%m%d%H%M%S", train_start_time)
# acc_time = 0

# create directory for results storage
store_dir = '{}/model_{}'.format(cfg['model_path'], time_stamp)
os.makedirs(store_dir, exist_ok=True)

best_model_fn = '{}/epoch_{}.pth.tar'.format(store_dir, 1)
loss_fn = '{}/loss.txt'.format(store_dir)
log_fn = '{}/log.txt'.format(store_dir)

val_result_path = '{}/results_val'.format(store_dir)
os.makedirs(val_result_path, exist_ok=True)

test_result_path = '{}/results_test'.format(store_dir)
os.makedirs(test_result_path, exist_ok=True)

# Dataloader
folds, _ = create_folds(data_path=cfg['data_path_train'], fold_num=cfg['fold_num'], exclude_case=cfg['exclude_case'])

'''create training and validation fold'''
train_fold = []
for i in range(cfg['fold_num'] - 2):
    train_fold.extend(folds[i])

# d_train = Dataset(train_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'])
d_train = DatasetStk(train_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'], perturb=True)
dl_train = data.DataLoader(dataset=d_train, batch_size=cfg['batch_size'], shuffle=True, pin_memory=True, drop_last=True, num_workers=cfg['cpu_thread'])

# create validaion fold
val_fold = folds[cfg['fold_num']-2]
# d_val = Dataset(val_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'])
d_val = DatasetStk(val_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'], perturb=False)
dl_val = data.DataLoader(dataset=d_val, batch_size=cfg['batch_size'], shuffle=False, pin_memory=True, drop_last=False, num_workers=cfg['cpu_thread'])

# create test fold
test_fold = folds[cfg['fold_num']-1]
#d_val = Dataset(val_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'])
d_test = DatasetStk(test_fold, rs_size=cfg['rs_size'], rs_spacing=cfg['rs_spacing'], rs_intensity=cfg['rs_intensity'], label_map=cfg['label_map'], cls_num=cfg['cls_num'], perturb=False)
dl_test = data.DataLoader(dataset=d_test, batch_size=cfg['batch_size'], shuffle=False, pin_memory=True, drop_last=False, num_workers=cfg['cpu_thread'])

'''--- train downstream tasks ---'''
Solver = WGAN_De_Att()
Solver.Net_train(cfg, dl_train, dl_val, loss_fn, store_dir)


'''--- eval downstream tasks ---'''
ii = 119
file_path = ' '
model_path = os.path.join(file_path,'V_epoch'+str(ii)+'.pth')
model = VNet()
model.cuda()
model.load_state_dict(torch.load(model_path))
netU = nn.DataParallel(module=model)
netU.eval()
Solver = WGAN_De_Att(netU)
Solver.batch_eval(cfg=cfg, 
                  dl_val=dl_test, 
                  epoch_id=0, 
                  loss_fn=loss_fn, 
                  eval_test=True)

dataset.py

import os

import sys

import torch

from torch.utils import data

import itk

import numpy as np

import random

import SimpleITK as sitk

def read_image(fname, imtype):

    reader = itk.ImageFileReader[imtype].New()

    reader.SetFileName(fname)

    reader.Update()

    image = reader.GetOutput()

    return image

def scan_path(d_name, d_path):

    entries = []

    if d_name == 'LiTS':

        for f in os.listdir(d_path):

            if f.startswith('volume-') and f.endswith('.mha'):

                id = int(f.split('.mha')[0].split('volume-')[1])

                if os.path.isfile('{}/segmentation-{}.mha'.format(d_path, id)):

                    case_name = 'volume-{}'.format(id)

                    image_name = '{}/volume-{}.mha'.format(d_path, id)

                    label_name = '{}/segmentation-{}.mha'.format(d_path, id)

                    entries.append([d_name, case_name, image_name, label_name])

    elif d_name == 'KiTS':

        for case_name in os.listdir(d_path):

            image_name = '{}/{}/imaging.nii.gz'.format(d_path, case_name)

            image_ld_name = '{}/{}/imaging_ld.nii.gz'.format(d_path, case_name)

            label_name = '{}/{}/segmentation.nii.gz'.format(d_path, case_name)

            if os.path.isfile(image_name) and os.path.isfile(label_name) and os.path.isfile(image_ld_name):

                entries.append([d_name, case_name, image_name, image_ld_name, label_name])

    elif d_name == 'BTCV':

        for f in os.listdir(d_path):

            if f.startswith('volume-'):

                id = int(f.split('.nii')[0].split('volume-')[1])

                if os.path.isfile('{}/segmentation-{}.nii.gz'.format(d_path, id)):

                    case_name = 'volume-{}'.format(id)

                    image_name = '{}/volume-{}.nii.gz'.format(d_path, id)

                    label_name = '{}/segmentation-{}.nii.gz'.format(d_path, id)

                    entries.append([d_name, case_name, image_name, label_name])

    elif d_name == 'spleen':

        for f in os.listdir('{}/imagesTr'.format(d_path)):

            if f.startswith('spleen_'):

                id = int(f.split('.nii.gz')[0].split('spleen_')[1])

                if os.path.isfile('{}/labelsTr/{}'.format(d_path, f)):

                    case_name = 'spleen_{}'.format(id)

                    image_name = '{}/imagesTr/{}'.format(d_path, f)

                    label_name = '{}/labelsTr/{}'.format(d_path, f)

                    entries.append([d_name, case_name, image_name, label_name])

    return entries

def create_folds(data_path, fold_num, exclude_case):

    fold_file_name = '{0:s}/CV_{1:d}-fold.txt'.format(sys.path[0], fold_num)

    folds = {}

    if os.path.exists(fold_file_name):

        with open(fold_file_name, 'r') as fold_file:

            strlines = fold_file.readlines()

            for strline in strlines:

                strline = strline.rstrip('\n')

                params = strline.split()

                fold_id = int(params[0])

                if fold_id not in folds:

                    folds[fold_id] = []

                folds[fold_id].append([params[1], params[2], params[3], params[5], params[4]])

    else:

        entries = []

        for [d_name, d_path] in data_path:             

            entries.extend(scan_path(d_name, d_path))

        for e in entries:

            if e[0:2] in exclude_case:

                entries.remove(e)

        case_num = len(entries)

        fold_size = int(case_num / fold_num)

        random.shuffle(entries)

        for fold_id in range(fold_num - 1):

            folds[fold_id] = entries[fold_id * fold_size:(fold_id + 1) * fold_size]

        folds[fold_num - 1] = entries[(fold_num - 1) * fold_size:]

        

        with open(fold_file_name, 'w') as fold_file:

            for fold_id in range(fold_num):

                print(fold_id)

                for [d_name, casename, image_fn, image_ld_fn, label_fn] in folds[fold_id]:

                    print(d_name, casename, image_fn, image_ld_fn, label_fn)

                    fold_file.write('{0:d} {1:s} {2:s} {3:s} {4:s} {5:s}\n'.format(fold_id, d_name, casename, image_fn, image_ld_fn, label_fn))

                    

    folds_size = [len(x) for x in folds.values()]

    return folds, folds_size

def normalize(x, min, max):

    factor = 1.0 / (max - min)

    x[x < min] = min

    x[x > max] = max

    x = (x - min) * factor

    return x

def generate_transform(identity):

    if identity:

        t = itk.IdentityTransform[itk.D, 3].New()

    else:

        min_rotate = -0.05   # [rad]

        max_rotate = 0.05   # [rad]

        min_offset = -5.0   # [mm]

        max_offset = 5.0    # [mm]

        t = itk.Euler3DTransform[itk.D].New()

        euler_parameters = t.GetParameters()

        euler_parameters = itk.OptimizerParameters[itk.D](t.GetNumberOfParameters())

        euler_parameters[0] = min_rotate + random.random() * (max_rotate - min_rotate) # rotate

        euler_parameters[1] = min_rotate + random.random() * (max_rotate - min_rotate) # rotate

        euler_parameters[2] = min_rotate + random.random() * (max_rotate - min_rotate) # rotate

        euler_parameters[3] = min_offset + random.random() * (max_offset - min_offset) # tranlate

        euler_parameters[4] = min_offset + random.random() * (max_offset - min_offset) # tranlate

        euler_parameters[5] = min_offset + random.random() * (max_offset - min_offset) # tranlate

        t.SetParameters(euler_parameters)

    return t

def resample(image, imtype, size, spacing, origin, transform, linear, dtype):

    o_origin = image.GetOrigin()

    o_spacing = image.GetSpacing()

    o_size = image.GetBufferedRegion().GetSize()

    output = {}

    output['org_size'] = np.array(o_size, dtype=int)

    output['org_spacing'] = np.array(o_spacing, dtype=float)

    output['org_origin'] = np.array(o_origin, dtype=float)

    

    if origin is None:   # if no origin point specified, center align the resampled image with the original image

        new_size = np.zeros(3, dtype=int)

        new_spacing = np.zeros(3, dtype=float)

        new_origin = np.zeros(3, dtype=float)

        for i in range(3):

            new_size[i] = size[i]

            if spacing[i] > 0:

                new_spacing[i] = spacing[i]

                new_origin[i] = o_origin[i] + o_size[i]*o_spacing[i]*0.5 - size[i]*spacing[i]*0.5

            else:

                new_spacing[i] = o_size[i] * o_spacing[i] / size[i]

                new_origin[i] = o_origin[i]

    else:

        new_size = np.array(size, dtype=int)

        new_spacing = np.array(spacing, dtype=float)

        new_origin = np.array(origin, dtype=float)

    output['size'] = new_size

    output['spacing'] = new_spacing

    output['origin'] = new_origin

    resampler = itk.ResampleImageFilter[imtype, imtype].New()

    resampler.SetInput(image)

    resampler.SetSize((int(new_size[0]), int(new_size[1]), int(new_size[2])))

    resampler.SetOutputSpacing((float(new_spacing[0]), float(new_spacing[1]), float(new_spacing[2])))

    resampler.SetOutputOrigin((float(new_origin[0]), float(new_origin[1]), float(new_origin[2])))

    resampler.SetTransform(transform)

    if linear:

        resampler.SetInterpolator(itk.LinearInterpolateImageFunction[imtype, itk.D].New())

    else:

        resampler.SetInterpolator(itk.NearestNeighborInterpolateImageFunction[imtype, itk.D].New())

    resampler.SetDefaultPixelValue(int(np.min(itk.GetArrayFromImage(image))))

    resampler.Update()

    rs_image = resampler.GetOutput()

    image_array = itk.GetArrayFromImage(rs_image)

    image_array = image_array[np.newaxis, :].astype(dtype)

    output['array'] = image_array

    return output

def make_onehot(input, cls):

    oh = np.repeat(np.zeros_like(input), cls * 2, axis=0)

    for i in range(cls):

        tmp = np.zeros_like(input)

        tmp[input == i + 1] = 1

        oh[i * 2 + 0 , :] = 1 - tmp

        oh[i * 2 + 1 , :] = tmp

    return oh

def make_flag(cls, labelmap):

    flag = np.zeros([cls, 1])

    for key in labelmap:

        flag[labelmap[key] - 1 , 0] = 1

    return flag

def image2file(image, imtype, fname):

    writer = itk.ImageFileWriter[imtype].New()

    writer.SetInput(image)

    writer.SetFileName(fname)

    writer.Update()

def array2file(array, size, origin, spacing, imtype, fname):    

    image = itk.GetImageFromArray(array.reshape([size[2], size[1], size[0]]))

    image.SetSpacing((spacing[0], spacing[1], spacing[2]))

    image.SetOrigin((origin[0], origin[1], origin[2]))

    image2file(image, imtype=imtype, fname=fname)

# dataset of 3D image volume

# 3D volumes are resampled from and center-aligned with the original images

class Dataset(data.Dataset):

    def __init__(self, ids, rs_size, rs_spacing, rs_intensity, label_map, cls_num):

        self.ImageType = itk.Image[itk.SS, 3]

        self.LabelType = itk.Image[itk.UC, 3]

        self.ids = ids

        self.rs_size = rs_size

        self.rs_spacing = rs_spacing

        self.rs_intensity = rs_intensity

        self.label_map = label_map

        self.cls_num = cls_num

    

    def __len__(self):

        return len(self.ids)

    def __getitem__(self, index):

        [d_name, casename, image_fn, image_ld_fn, label_fn] = self.ids[index]

        t = generate_transform(identity=True)

        src_image = read_image(fname=image_fn, imtype=self.ImageType)

        image = resample(

            image=src_image, imtype=self.ImageType, 

            size=self.rs_size, spacing=self.rs_spacing, origin=None, 

            transform=t, linear=True, dtype=np.float32)

        image['array'] = normalize(image['array'], min=self.rs_intensity[0], max=self.rs_intensity[1])

        

        src_image_ld = read_image(fname=image_ld_fn, imtype=self.ImageType)

        image_ld = resample(

            image=src_image_ld, imtype=self.ImageType, 

            size=self.rs_size, spacing=self.rs_spacing, origin=None, 

            transform=t, linear=True, dtype=np.float32)

        image_ld['array'] = normalize(image_ld['array'], min=self.rs_intensity[0], max=self.rs_intensity[1])

        

        src_label = read_image(fname=label_fn, imtype=self.LabelType)

        label = resample(

            image=src_label, imtype=self.LabelType, 

            size=self.rs_size, spacing=self.rs_spacing, origin=None, 

            transform=t, linear=False, dtype=np.int64)

        tmp_array = np.zeros_like(label['array'])

        lmap = self.label_map[d_name]

        for key in lmap:

            tmp_array[label['array'] == key] = lmap[key]

        label['array'] = tmp_array

        label_bin = make_onehot(label['array'], cls=self.cls_num)

        label_exist = make_flag(cls=self.cls_num, labelmap=self.label_map[d_name])

        image_tensor = torch.from_numpy(image['array'])

        image_ld_tensor = torch.from_numpy(image_ld['array'])

        label_tensor = torch.from_numpy(label_bin)

        output = {}

        output['data'] = image_tensor

        output['data_ld'] = image_ld_tensor

        output['label'] = label_tensor

        output['label_exist'] = label_exist

        output['dataset'] = d_name

        output['case'] = casename

        output['size'] = image['size']

        output['spacing'] = image['spacing']

        output['origin'] = image['origin']

        output['org_size'] = image['org_size']

        output['org_spacing'] = image['org_spacing']

        output['org_origin'] = image['org_origin']

        output['eof'] = True

        return output

# dataset of image stacks (short-length 3D image volume)

# each image is resampled as a series adjacent image stacks

# the image stacks cover the whole range of image length

class DatasetStk(data.Dataset):

    def __init__(self, ids, rs_size, rs_spacing, rs_intensity, label_map, cls_num, perturb):

        self.ImageType = itk.Image[itk.SS, 3]

        self.LabelType = itk.Image[itk.UC, 3]

        self.ids = []

        self.rs_size = rs_size

        self.rs_spacing = rs_spacing

        self.rs_intensity = rs_intensity

        self.label_map = label_map

        self.cls_num = cls_num

        self.perturb = perturb

        for i, [d_name, casename, image_fn, image_ld_fn, label_fn] in enumerate(ids):

            print('Preparing image stacks ({}/{}) ...'.format(i, len(ids)))

            reader = sitk.ImageFileReader()

            reader.SetFileName(image_fn)

            reader.ReadImageInformation()

            size = reader.GetSize()

            spacing = reader.GetSpacing()

            origin = reader.GetOrigin()

            stack_len = rs_size[2] * rs_spacing[2]

            '''

            image_len = size[2] * spacing[2]

            stack_num = int(image_len / stack_len) + 1

            for stack_id in range(stack_num):

                stack_size = np.array(rs_size, dtype=int)

                stack_spacing = np.array(rs_spacing, dtype=float)

                stack_origin = np.zeros(3, dtype=float)

                stack_origin[0] = origin[0] + 0.5 * size[0] * spacing[0] - 0.5 * rs_size[0] * rs_spacing[0]

                stack_origin[1] = origin[1] + 0.5 * size[1] * spacing[1] - 0.5 * rs_size[1] * rs_spacing[1]

                #stack_origin[2] = origin[2] - 0.5 * (stack_num * stack_len - image_len) + stack_id * stack_len

                stack_perturb = np.zeros(2, dtype=float)

                if stack_num > 1:

                    stack_origin[2] = origin[2] + stack_id * (image_len - stack_len) / (stack_num - 1)

                    stack_perturb[0] = max(stack_origin[2] - 0.5 * stack_len, origin[2])

                    stack_perturb[1] = min(stack_origin[2] + 0.5 * stack_len, origin[2] + image_len - stack_len)

                else:

                    stack_origin[2] = origin[2] + 0.5 * (image_len - stack_len)

                    stack_perturb[0] = stack_origin[2]

                    stack_perturb[1] = stack_origin[2]

                self.ids.append([d_name, casename, image_fn, label_fn, stack_id, stack_size, stack_spacing, stack_origin, stack_perturb, stack_id == stack_num-1])

            '''

            lb_reader = sitk.ImageFileReader()

            lb_reader.SetFileName(label_fn)             

            lb_volume = lb_reader.Execute()

            lb_array = sitk.GetArrayFromImage(lb_volume)

            tmp_array = np.zeros_like(lb_array)

            lmap = self.label_map[d_name]

            for key in lmap:

                tmp_array[lb_array == key] = lmap[key]

            lb_array = tmp_array

            nz_ind = np.nonzero(lb_array > 0)

            lb_size = np.zeros(3, dtype=np.float)

            lb_origin = np.zeros(3, dtype=np.float)

            #for i in range(3):

            #    lb_size[i] = (np.max(nz_ind[2-i]) - np.min(nz_ind[2-i]) + 1) * spacing[i]

            #    lb_origin[i] = origin[i] + np.min(nz_ind[2-i]) * spacing[i]

            lb_size[2] = (np.max(nz_ind[0]) - np.min(nz_ind[0]) + 1) * spacing[2]

            lb_origin[2] = origin[2] + np.min(nz_ind[0]) * spacing[2]

            lb_len = lb_size[2]

                   

            if lb_len > stack_len:

                stack_num = int((lb_len + stack_len) / stack_len) + 1

            else:

                stack_num = 1

            for stack_id in range(stack_num):

                stack_size = np.array(rs_size, dtype=int)

                stack_spacing = np.array(rs_spacing, dtype=float)

                stack_origin = np.zeros(3, dtype=float)

                stack_origin[0] = origin[0] + 0.5 * size[0] * spacing[0] - 0.5 * rs_size[0] * rs_spacing[0]

                stack_origin[1] = origin[1] + 0.5 * size[1] * spacing[1] - 0.5 * rs_size[1] * rs_spacing[1]

                #  stack_origin[2] = origin[2] - 0.5 * (stack_num * stack_len - image_len) + stack_id * stack_len

                stack_perturb = np.zeros(2, dtype=float)

                if stack_num > 1:

                    stack_origin[2] = lb_origin[2] - 0.5 * stack_len + stack_id * lb_len / (stack_num - 1)

                    stack_perturb[0] = max(stack_origin[2] - 0.5 * stack_len, lb_origin[2] - 0.5 * stack_len)

                    stack_perturb[1] = min(stack_origin[2] + 0.5 * stack_len, lb_origin[2] + lb_len - 0.5 * stack_len)

                else:

                    stack_origin[2] = lb_origin[2] + 0.5 * (lb_len - stack_len)

                    stack_perturb[0] = lb_origin[2] - 0.5 * stack_len

                    stack_perturb[1] = lb_origin[2] + lb_len - 0.5 * stack_len

                self.ids.append([d_name, casename, image_fn, image_ld_fn, label_fn, stack_id, stack_size, stack_spacing, stack_origin, stack_perturb, stack_id == stack_num - 1])

            

    def __len__(self):

        return len(self.ids)

    def __getitem__(self, index):

        [d_name, casename, image_fn, image_ld_fn, label_fn, _, stack_size, stack_spacing, base_origin, stack_perturb, eof] = self.ids[index]

        stack_origin = base_origin.copy()

        if self.perturb:

            #  stack_len = stack_size[2] * stack_spacing[2]

            #  stack_origin[2] = stack_origin[2] + (random.random() - 0.5) * stack_len

            stack_origin[2] = stack_perturb[0] + random.random() * (stack_perturb[1] - stack_perturb[0])

        t = generate_transform(identity=True)

        src_image = read_image(fname=image_fn, imtype=self.ImageType)

        image = resample(

                    image = src_image, imtype=self.ImageType, 

                    size=stack_size, spacing=stack_spacing, origin=stack_origin, 

                    transform=t, linear=True, dtype=np.float32)        

        image['array'] = normalize(image['array'], min=self.rs_intensity[0], max=self.rs_intensity[1])    

        src_image_ld = read_image(fname=image_ld_fn, imtype=self.ImageType)  

        src_image_ld.SetOrigin(src_image.GetOrigin())

        src_image_ld.SetSpacing(src_image.GetSpacing())

        src_image_ld.SetDirection(src_image.GetDirection())

        image_ld = resample(

            image=src_image_ld, imtype=self.ImageType, 

            size=stack_size, spacing=stack_spacing, origin=stack_origin, 

            transform=t, linear=True, dtype=np.float32)        

        image_ld['array'] = normalize(image_ld['array'], min=self.rs_intensity[0], max=self.rs_intensity[1])

        

        src_label = read_image(fname=label_fn, imtype=self.LabelType)

        

        src_label.SetOrigin(src_image.GetOrigin())

        src_label.SetSpacing(src_image.GetSpacing())

        src_label.SetDirection(src_image.GetDirection())

        label = resample(

            image=src_label, imtype=self.LabelType, 

            size=stack_size, spacing=stack_spacing, origin=stack_origin, 

            transform=t, linear=False, dtype=np.int64)

        tmp_array = np.zeros_like(label['array'])

        lmap = self.label_map[d_name]

        for key in lmap:

            tmp_array[label['array'] == key] = lmap[key]

        label['array'] = tmp_array

#         if casename == 'case_00002':

#         print(casename,'\n')

#         print('label array\n')

#         print(np.max(label['array']), np.min(label['array']))

#         src_label_arr = itk.GetArrayFromImage(src_label)

#         print('src_label array\n')

#         print(np.max(src_label_arr), np.min(src_label_arr))

   

        label_bin = make_onehot(label['array'], cls=self.cls_num)

        label_exist = make_flag(cls=self.cls_num, labelmap=self.label_map[d_name])

        image_tensor = torch.from_numpy(image['array'])

        image_ld_tensor = torch.from_numpy(image_ld['array'])

        label_tensor = torch.from_numpy(label_bin)

        output = {}

        output['data'] = image_tensor

        output['data_ld'] = image_ld_tensor

        output['label'] = label_tensor

        output['label_exist'] = label_exist

        output['dataset'] = d_name

        output['case'] = casename

        output['size'] = image['size']

        output['spacing'] = image['spacing']

        output['origin'] = image['origin']

        output['org_size'] = image['org_size']

        output['org_spacing'] = image['org_spacing']

        output['org_origin'] = image['org_origin']

        output['eof'] = eof

        return output