# Regression problem for 3D MRI images

Problem:
I am trying to predict 512-element vectors from MR images. The actual MR image size is (172, 220, 156). The problem in my work is that I cannot do patch-wise training, because each region of the brain volume is correlated with the other regions in terms of electromagnetic fields, so I have to feed the whole image to the network. Feeding such a large image (almost 6 million voxels) is computationally expensive (it produces a CUDA out-of-memory error on my GTX 1080 Ti GPU), so, partly also for data augmentation, we subsampled the main images into 8 lower-resolution images each.
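For concreteness, here is a minimal sketch of one way such a subsampling could be done, assuming the 8 lower-resolution volumes are produced by taking every other voxel with the 8 possible (x, y, z) offsets, which matches the (86, 110, 78) sub-volume size mentioned further down:

```python
import numpy as np

def subsample_volume(vol: np.ndarray):
    """Split a full-resolution volume into 8 half-resolution sub-volumes.

    Each sub-volume keeps every other voxel along each axis, starting from
    one of the 8 possible (x, y, z) offsets. For a (172, 220, 156) input
    this yields 8 volumes of shape (86, 110, 78).
    """
    subs = []
    for dx in (0, 1):
        for dy in (0, 1):
            for dz in (0, 1):
                subs.append(vol[dx::2, dy::2, dz::2])
    return subs

# Example with a dummy volume of the stated size:
vol = np.random.rand(172, 220, 156)
subs = subsample_volume(vol)
print(len(subs), subs[0].shape)  # 8 (86, 110, 78)
```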

Every group of 8 subsampled images shares the same target: a vector of size (1, 512), where the first 256 entries are the real parts and the last 256 are the imaginary parts. Please take a look at the target below:
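Just to illustrate the target layout (this is only an assumption about how a complex-valued vector would be flattened into the (1, 512) target, not the actual preprocessing code):

```python
import torch

# Hypothetical complex-valued target of length 256 for one subject.
complex_target = torch.randn(256, dtype=torch.cfloat)

# Flatten to a (1, 512) real tensor: first 256 real parts, last 256 imaginary parts.
target = torch.cat([complex_target.real, complex_target.imag]).unsqueeze(0)
print(target.shape)  # torch.Size([1, 512])
```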

I am using ResNet3D architecture:

```
ResNet(
  (strConv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (strConv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (strConv3): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
  (conv_block1_32): ConvBlock(
    (conv1): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn2): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_block32_64): residualUnit(
    (conv1): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (convX): Conv3d(32, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (bnX): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_block64_128): residualUnit(
    (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn2): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (convX): Conv3d(64, 128, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (bnX): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_block128_256): residualUnit(
    (conv1): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn2): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (convX): Conv3d(128, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (bnX): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_block256_512): residualUnit(
    (conv1): Conv3d(256, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn1): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv3d(512, 512, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (bn2): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (convX): Conv3d(256, 512, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (bnX): BatchNorm3d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (fc1): Linear(in_features=394240, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=512, bias=True)
)
```

The `convX` layers are the 1×1×1 convolutions on the residual skip connections, and the `strConv` layers are strided convolutions used in place of `MaxPool` for downsampling. A `ReLU` activation follows every `Conv+BatchNorm` pair; these activations do not show up in the printout above. A rough sketch of such a residual unit is shown below.
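This is only an illustration consistent with the printout above, not the author's exact module; in particular the ReLU placement is assumed:

```python
import torch
import torch.nn as nn

class residualUnit(nn.Module):
    """Two 3x3x3 Conv+BN+ReLU layers with a 1x1x1 projection skip connection."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(out_ch)
        self.conv2 = nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm3d(out_ch)
        # convX/bnX: 1x1x1 projection so the skip connection matches out_ch channels.
        self.convX = nn.Conv3d(in_ch, out_ch, kernel_size=1)
        self.bnX = nn.BatchNorm3d(out_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = self.bnX(self.convX(x))
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + identity)
```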

I am also using the weight initialization below:

```python
import torch.nn as nn
import torch.nn.init as init

def weights_init(m):
    if isinstance(m, nn.Conv3d):
        # Xavier/Glorot uniform init, scaled with the gain recommended for ReLU.
        init.xavier_uniform_(m.weight.data, init.calculate_gain('relu'))
        m.bias.data.fill_(0)
    elif isinstance(m, nn.BatchNorm3d):
        m.weight.data.normal_(mean=1.0, std=0.02)
        m.bias.data.fill_(0)
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0.0, 0.02)
        m.bias.data.fill_(0)
```
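The initializer is presumably applied with `Module.apply`; as a usage sketch, assuming the network instance is called `model` and `ResNet` is the author's class from the printout above:

```python
model = ResNet()           # the 3D ResNet described above
model.apply(weights_init)  # recursively applies weights_init to every submodule
```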

I trained the model on the 144 subsampled images (size 86×110×78), once for 60 epochs and once for 100 epochs, using the Adam optimizer. Please see the hyperparameters:

```
Criterion: MSELoss()
Optimizer: Adam (
Parameter Group 0
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0001
    weight_decay: 0
)
```
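In code, that setup corresponds to something like the following sketch (assuming `model` is the network instance from above; the variable names are mine):

```python
import torch
import torch.nn as nn

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-4,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=0,
)
```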

The outputs are not so great.

Is there anything I am doing wrong?

The training loss looks like this:

Is there anything I am missing big time?