Hi, I have a 3D MRI image of size (128, 128, 128) as input to my model. When it enters the model, it has the shape (8, 4, 128, 128, 128), i.e. (Batch, Channels, H, W, D).
I would like to separate the channels and perform a convolution on each (32, 32, 32) block of this (128, 128, 128) input. Then I want to take each block's conv weights, multiply them with the input values to that conv, and reassemble the results into a (128, 128, 128) block.
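For reference, the block split and reassembly described above can be done entirely in PyTorch with `reshape`/`permute` instead of scikit-image's `view_as_blocks`. This is a minimal sketch: the helper names `to_blocks` and `from_blocks` are hypothetical, and it assumes the block size divides H, W and D evenly. For a contiguous input the split is a view (no copy), so it stays on the GPU and keeps autograd intact.

```python
import torch


def to_blocks(x: torch.Tensor, b: int) -> torch.Tensor:
    """Split (N, C, H, W, D) into (N, C, H//b, W//b, D//b, b, b, b) blocks.

    A torch counterpart of skimage's view_as_blocks on the last three dims.
    """
    n, c, h, w, d = x.shape
    assert h % b == 0 and w % b == 0 and d % b == 0, "b must divide H, W and D"
    # Split each spatial axis into (block index, position inside block).
    x = x.reshape(n, c, h // b, b, w // b, b, d // b, b)
    # Move the three block-index axes in front of the three in-block axes.
    return x.permute(0, 1, 2, 4, 6, 3, 5, 7)


def from_blocks(blocks: torch.Tensor) -> torch.Tensor:
    """Inverse of to_blocks: (N, C, gh, gw, gd, b, b, b) -> (N, C, gh*b, gw*b, gd*b)."""
    n, c, gh, gw, gd, bh, bw, bd = blocks.shape
    # Interleave block-index and in-block axes again, then merge each pair.
    x = blocks.permute(0, 1, 2, 5, 3, 6, 4, 7)
    return x.reshape(n, c, gh * bh, gw * bw, gd * bd)


if __name__ == "__main__":
    x = torch.randn(2, 4, 128, 128, 128)
    blocks = to_blocks(x, 32)
    print(blocks.shape)  # torch.Size([2, 4, 4, 4, 4, 32, 32, 32])
    # Round trip restores the original volume exactly.
    assert torch.equal(from_blocks(blocks), x)
    # Block (1, 2, 3) covers voxels [32:64, 64:96, 96:128].
    assert torch.equal(blocks[:, :, 1, 2, 3], x[:, :, 32:64, 64:96, 96:128])
```

Any per-block operation can then be applied by broadcasting over the `(gh, gw, gd)` axes before calling `from_blocks`.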
My current, inefficient solution (using many for loops and scikit-image) is below. However, it takes too long and requires too much memory. What's the best way to do this?
```python
import torch
import torch.nn as nn
from skimage.util.shape import view_as_blocks


class LFBlock(nn.Module):
    def __init__(self, input_shape=(128, 128, 128), kernel_size=(1, 1, 1), blk_div=4):
        super(LFBlock, self).__init__()
        # Divides the (128,128,128)//4 -> (32,32,32)
        self.block_shape = (input_shape[0] // blk_div,
                            input_shape[1] // blk_div,
                            input_shape[2] // blk_div)
        self.num_blocks = (input_shape[0] // self.block_shape[0]) * \
                          (input_shape[1] // self.block_shape[1]) * \
                          (input_shape[2] // self.block_shape[2])
        conv_list = []
        for n in range(self.num_blocks):
            conv_list.append(nn.Conv3d(1, 1, kernel_size=kernel_size, stride=1, padding=0, bias=True))
        self.conv1x1s = nn.ModuleList(conv_list)

    def forward(self, lf_in):
        # Batch
        for i in range(lf_in.shape[0]):
            # Modality
            for ch in range(lf_in.shape[1]):
                x_lf = lf_in[i, ch, :]
                # Note: for a CUDA tensor, .cpu() makes a copy, so in-place
                # writes to lf_blocks do not reach lf_in.
                lf_blocks = view_as_blocks(x_lf.cpu().numpy(), block_shape=self.block_shape)
                # Do Conv3d on each block
                for x in range(lf_blocks.shape[0]):
                    for y in range(lf_blocks.shape[1]):
                        for z in range(lf_blocks.shape[2]):
                            # Row-major index of block (x, y, z)
                            conv_idx = (x * lf_blocks.shape[1] + y) * lf_blocks.shape[2] + z
                            # Convolve the block, then multiply with the weight of the block.
                            tensor_img = torch.from_numpy(lf_blocks[x, y, z])[None, None, :]
                            conv = self.conv1x1s[conv_idx](tensor_img.cuda())
                            # w * x.
                            # view_as_blocks returns a view so modifications are done in-place
                            lf_blocks[x, y, z] = tensor_img.cpu() * self.conv1x1s[conv_idx].weight.data.cpu()
        # Linearly sum the modalities together
        # out = w0*x0 + w1*x1 + w2*x2 + w3*x3
        out = (lf_in[:, 0] + lf_in[:, 1] + lf_in[:, 2] + lf_in[:, 3])[:, None]
        return out
```
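With `kernel_size=(1, 1, 1)` as in the code above, each per-block `Conv3d(1, 1)` weight is a single scalar, so the nested loops amount to multiplying every (32, 32, 32) block by its own learnable scalar and then summing the channels. Below is a minimal vectorized sketch under that assumption; the class name `BlockwiseScale` and the all-ones initialization are hypothetical choices. Like the loop version, it ignores the conv bias and shares one weight per block across the batch and the channels.

```python
import torch
import torch.nn as nn


class BlockwiseScale(nn.Module):
    """Hypothetical vectorized stand-in for the loop-based LFBlock.

    Assumes kernel_size=(1, 1, 1): each per-block Conv3d(1, 1) weight is then
    one scalar, so this learns one scalar per block (bias ignored, as in the
    loop version), shared across the batch and channels.
    """

    def __init__(self, input_shape=(128, 128, 128), blk_div=4):
        super().__init__()
        # (128, 128, 128) // 4 -> (32, 32, 32) voxels per block
        self.block_shape = tuple(s // blk_div for s in input_shape)
        # One learnable scalar per block: a (4, 4, 4) grid for the defaults.
        # All-ones init (identity scaling) is an arbitrary choice.
        self.weight = nn.Parameter(torch.ones(blk_div, blk_div, blk_div))

    def forward(self, lf_in):
        # lf_in: (B, C, H, W, D)
        bh, bw, bd = self.block_shape
        # Expand the weight grid so every voxel gets its block's weight:
        # (4, 4, 4) -> (128, 128, 128) for the defaults.
        w = (self.weight
             .repeat_interleave(bh, dim=0)
             .repeat_interleave(bw, dim=1)
             .repeat_interleave(bd, dim=2))
        # w * x for all blocks at once (broadcast over B and C),
        # then sum the modalities into a single channel.
        return (lf_in * w).sum(dim=1, keepdim=True)


if __name__ == "__main__":
    model = BlockwiseScale()
    lf_in = torch.randn(2, 4, 128, 128, 128)
    print(model(lf_in).shape)  # torch.Size([2, 1, 128, 128, 128])
```

The expanded weight is a single 128³ tensor (8 MB in float32), everything stays a torch op on whatever device the model lives on, and gradients flow back to `self.weight`. If separate weights per modality are wanted (the w0..w3 in the comment), the parameter could instead have shape (C, 4, 4, 4) with the repeats applied to dims 1 to 3.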
Any help is appreciated. Thank you!