Regression problem for 3D MRI images

Problem:
I am trying to predict a vector of 512 values from MR images. The actual MR image size is (172, 220, 156). The difficulty in my work is that I cannot do patch-wise training, because each region of the brain volume is correlated with the other regions in terms of electromagnetic fields, so I have to feed the whole image to the network. Feeding an image this large (almost 6 million voxels) is computationally expensive (it triggers a CUDA out-of-memory error on my GTX 1080 Ti), so, both to fit in memory and for data augmentation, we subsampled each original image into 8 lower-resolution images, roughly as in the sketch below.
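(Simplified sketch of the 2x subsampling, assuming the volumes are NumPy arrays; the function name is just for illustration:)

import numpy as np

def subsample(vol):
    # Take every 2nd voxel along each axis at all 8 parity offsets,
    # turning one (172, 220, 156) volume into eight (86, 110, 78) volumes.
    return [vol[i::2, j::2, k::2]
            for i in (0, 1) for j in (0, 1) for k in (0, 1)]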

Every group of 8 subsampled images shares the same target: a vector of size (1, 512), where the first 256 entries are the real parts and the last 256 are the imaginary parts. Please take a look at the target below:

[figure: the target vector]
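In code, the packing and unpacking of the complex target is roughly this (sketch; the helper names are just for illustration):

import numpy as np

def pack_target(z):
    # z: complex vector of length 256 -> real-valued target of shape (1, 512)
    return np.concatenate([z.real, z.imag]).reshape(1, 512)

def unpack_target(t):
    # Inverse: recover the complex vector from a (1, 512) prediction.
    t = t.reshape(512)
    return t[:256] + 1j * t[256:]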

I am using a 3D ResNet architecture:

ResNet(
  (strConv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (strConv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (strConv3): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (conv_block1_32): ConvBlock(
    (conv1): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn2): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_block32_64): residualUnit(
    (conv1): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (convX): Conv3d(32, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (bnX): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_block64_128): residualUnit(
    (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (convX): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (bnX): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_block128_256): residualUnit(
    (conv1): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn2): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (convX): Conv3d(128, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (bnX): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_block256_512): residualUnit(
    (conv1): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn2): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (convX): Conv3d(256, 512, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (bnX): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (fc1): Linear(in_features=394240, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=512, bias=True)
)

The convX layers are 1x1x1 convolutions on the residual skip connections, and the strConv layers are strided convolutions used in place of max pooling. After each Conv+BatchNorm pair a ReLU activation is applied; these are functional calls, so they do not show up in the printout above.
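For clarity, each residualUnit is wired roughly like this (sketch of the block, reconstructed from the printout above):

import torch.nn as nn
import torch.nn.functional as F

class residualUnit(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_ch)
        self.conv2 = nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_ch)
        self.convX = nn.Conv3d(in_ch, out_ch, kernel_size=1)  # projection on the skip path
        self.bnX = nn.BatchNorm3d(out_ch)

    def forward(self, x):
        skip = self.bnX(self.convX(x))           # match the channel count
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        return F.relu(out + skip)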

I am also using the weight initialization below:

import torch.nn as nn
import torch.nn.init as init

def weights_init(m):
    if isinstance(m, nn.Conv3d):
        # Xavier/Glorot uniform scaled by the ReLU gain
        init.xavier_uniform_(m.weight.data, gain=init.calculate_gain('relu'))
        m.bias.data.fill_(0)
    elif isinstance(m, nn.BatchNorm3d):
        # scale ~ N(1, 0.02), shift = 0
        m.weight.data.normal_(mean=1.0, std=0.02)
        m.bias.data.fill_(0)
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
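
The initializer is applied in the usual way (assuming the model class is named ResNet):

model = ResNet()
model.apply(weights_init)  # runs weights_init on every submodule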

I trained the model on 144 subsampled images (size (86, 110, 78)) for both 60 and 100 epochs, using the Adam optimizer. Please see the hyperparameters:

Criterion: MSELoss()
Optimizer: Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0001
    weight_decay: 0
)
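
The training loop itself is standard (simplified sketch; num_epochs, train_loader, and the tensor shapes are placeholders for my actual setup):

import torch
import torch.nn as nn

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(num_epochs):            # 60 or 100
    for vols, targets in train_loader:     # vols: (B, 1, 86, 110, 78), targets: (B, 512)
        optimizer.zero_grad()
        loss = criterion(model(vols.cuda()), targets.cuda())
        loss.backward()
        optimizer.step()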

The outputs are not great:

[figure: predicted outputs vs. targets]

Is there anything I am doing wrong?

The training loss looks like this:

[figure: training loss curve]

Anything big I am missing?