Regression problem for 3D MRI images

Problem:
I am trying to predict 512-element vectors from MR images. The actual MR image size is (172, 220, 156). The problem in my work is that I cannot do patch-wise training, since each brain region is correlated with the other regions in terms of electromagnetic fields, so I have to feed the whole image to the network. Feeding the full-size image (almost 6 million voxels) is computationally expensive (it causes a CUDA out-of-memory error on my GTX 1080 Ti GPU), so, both to fit in GPU memory and for data augmentation, we subsampled each original image into 8 lower-resolution images.
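
A minimal sketch of a subsampling scheme that matches these sizes (the exact method used in the post is an assumption): taking every other voxel along each axis, once per 2×2×2 starting offset, turns one (172, 220, 156) volume into eight (86, 110, 78) volumes.

import numpy as np

def subsample_volume(vol):
    # One half-resolution volume per 2x2x2 starting offset
    return [vol[i::2, j::2, k::2]
            for i in range(2) for j in range(2) for k in range(2)]

vol = np.zeros((172, 220, 156), dtype=np.float32)
subs = subsample_volume(vol)  # 8 arrays, each of shape (86, 110, 78)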

All 8 subsampled images from one volume share the same target: a vector of size (1, 512), where the first 256 values are the real parts and the last 256 are the imaginary parts. Please take a look at the target below:

[image: target vector]
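
For illustration, here is a hypothetical way to pack such a complex-valued target into the (1, 512) layout described above, and to recover it afterwards:

import torch

z = torch.randn(256, dtype=torch.complex64)        # complex target coefficients
target = torch.cat([z.real, z.imag]).unsqueeze(0)  # shape (1, 512): real parts, then imaginary
z_back = torch.complex(target[0, :256], target[0, 256:])
assert torch.allclose(z, z_back)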

I am using ResNet3D architecture:

ResNet(
  (strConv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (strConv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (strConv3): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (conv_block1_32): ConvBlock(
    (conv1): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn2): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_block32_64): residualUnit(
    (conv1): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (convX): Conv3d(32, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (bnX): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_block64_128): residualUnit(
    (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (convX): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (bnX): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_block128_256): residualUnit(
    (conv1): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn2): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (convX): Conv3d(128, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (bnX): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_block256_512): residualUnit(
    (conv1): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn2): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (convX): Conv3d(256, 512, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (bnX): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (fc1): Linear(in_features=394240, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=512, bias=True)
)

The convX layers are the residual skip connections, and the strConv layers are strided convolutions used in place of max pooling. A ReLU activation follows each Conv+BatchNorm pair; these do not show up in the printout above.
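
For reference, a residualUnit consistent with the printout above might look like the following sketch (the actual implementation is not shown in the post, so the forward ordering is an assumption):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualUnitSketch(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_ch)
        self.conv2 = nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_ch)
        self.convX = nn.Conv3d(in_ch, out_ch, kernel_size=1)  # 1x1x1 skip projection
        self.bnX = nn.BatchNorm3d(out_ch)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))  # ReLU after each Conv+BN
        out = self.bn2(self.conv2(out))
        skip = self.bnX(self.convX(x))         # residual skip connection
        return F.relu(out + skip)

x = torch.randn(1, 32, 8, 8, 8)
y = ResidualUnitSketch(32, 64)(x)  # -> (1, 64, 8, 8, 8)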

I am also using the weight initialization below:

import torch
import torch.nn as nn
import torch.nn.init as init

def weights_init(m):
    if isinstance(m, nn.Conv3d):
        # Xavier/Glorot initialization, scaled by the ReLU gain
        init.xavier_uniform_(m.weight.data, init.calculate_gain('relu'))
        m.bias.data.fill_(0)
    elif isinstance(m, nn.BatchNorm3d):
        # scale ~ N(1, 0.02), shift = 0
        m.weight.data.normal_(mean=1.0, std=0.02)
        m.bias.data.fill_(0)
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
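
The initializer is then applied recursively to every submodule with Module.apply (here model stands for an instance of the ResNet above):

model.apply(weights_init)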

I trained the model on the 144 subsampled images (size 86×110×78) for both 60 and 100 epochs, using the Adam optimizer. Please see the hyperparameters:

Criterion:MSELoss()
Optimizer:Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0001
    weight_decay: 0
)
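
As a self-contained sketch, the criterion and optimizer printed above can be reproduced like this (the Linear module is only a stand-in for the actual 3D ResNet):

import torch
import torch.nn as nn

model = nn.Linear(394240, 512)  # stand-in; the real model is the ResNet above
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4,
                             betas=(0.9, 0.999), eps=1e-8,
                             weight_decay=0, amsgrad=False)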

The outputs are not so great.

[image: model outputs]

Is there anything I am doing wrong?

The training loss looks like this:
[image: training loss curve]

Anything I am missing big time?


Hi @banikr,

I’m trying to solve a problem with quite a similar structure to yours, in a different application, and I’m getting quite similar results; mine is not working well either.
I was digging through the internet and found this thread. Did you make any progress? Any tips you’d like to share?

Thanks.
