Hi I have a 3D MRI image of size (128,128,128) as input to my model. When it enters the model it has the shape (8, 4, 128, 128, 128) which is (Batch, Channels, H, W, D).
I would like to separate the channels and perform a convolution on blocks of (32,32,32) for this (128,128,128) input. Then I wish to take the conv weights and multiply it with the input values to the conv and remap them to a (128,128,128) block.
My current inefficient solution (using many for loops and scikit-image) is below however it takes too long and requires too much memory. What’s the best way to do this?
from skimage.util.shape import view_as_blocks
class LFBlock(nn.Module):
def __init__(self, input_shape=(128,128,128), kernel_size=(1,1,1), blk_div=4):
super(LFBlock,self).__init__()
# Divides the (128,128,128)//4 -> (32,32,32)
self.block_shape = (input_shape[0]//blk_div, input_shape[1]//blk_div, input_shape[2]//blk_div)
self.num_blocks = (input_shape[0]//self.block_shape[0])*(input_shape[0]//self.block_shape[0])*\
(input_shape[0]//self.block_shape[0])
conv_list = []
for n in range(self.num_blocks):
conv_list.append(nn.Conv3d(1,1, kernel_size=kernel_size, stride=1, padding=0, bias=True))
self.conv1x1s = nn.ModuleList(conv_list)
def forward(self, lf_in):
# Batch
for i in range(lf_in.shape[0]):
# Modality
for ch in range(lf_in.shape[1]):
x_lf = lf_in[i,ch,:]
lf_blocks = view_as_blocks(x_lf.cpu().numpy(), block_shape=self.block_shape)
# Do Conv3d on each block
for x in range(len(lf_blocks)):
for y in range(len(lf_blocks)):
for z in range(len(lf_blocks)):
conv_idx = x*len(lf_blocks) + y*len(lf_blocks) + z
# Convolve the block, then multiply with the weight of the block.
tensor_img = torch.from_numpy(lf_blocks[x,y,z])[None, None,:]
conv = self.conv1x1s[conv_idx](tensor_img.cuda())
# w * x.
# view_as_blocks returns a view so modifications are done in-place
lf_blocks[x,y,z] = tensor_img.cpu()*self.conv1x1s[conv_idx].weight.data.cpu()
# Linearly sum the modalities together
# out = w0*x0 + w1*x1 + w2*x2 + w3*x3
out = (lf_in[:,0]+lf_in[:,1]+lf_in[:,2]+lf_in[:,3])[:,None]
return out
Any help is appreciated. Thank you!