Convolution on non-overlapping blocks of a 3D image then remapping them

Hi I have a 3D MRI image of size (128,128,128) as input to my model. When it enters the model it has the shape (8, 4, 128, 128, 128) which is (Batch, Channels, H, W, D).

I would like to separate the channels and perform a convolution on blocks of (32,32,32) for this (128,128,128) input. Then I wish to take the conv weights and multiply it with the input values to the conv and remap them to a (128,128,128) block.

My current inefficient solution (using many for loops and scikit-image) is below however it takes too long and requires too much memory. What’s the best way to do this?

from skimage.util.shape import view_as_blocks

class LFBlock(nn.Module):
    def __init__(self, input_shape=(128,128,128), kernel_size=(1,1,1), blk_div=4):
        super(LFBlock,self).__init__()      
        
        # Divides the (128,128,128)//4 -> (32,32,32)
        self.block_shape = (input_shape[0]//blk_div, input_shape[1]//blk_div, input_shape[2]//blk_div)
        
        self.num_blocks = (input_shape[0]//self.block_shape[0])*(input_shape[0]//self.block_shape[0])*\
        (input_shape[0]//self.block_shape[0])
        
        conv_list = []
        for n in range(self.num_blocks):
            conv_list.append(nn.Conv3d(1,1, kernel_size=kernel_size, stride=1, padding=0, bias=True))
            
        self.conv1x1s = nn.ModuleList(conv_list)        
                
    def forward(self, lf_in):
        # Batch
        for i in range(lf_in.shape[0]):
            # Modality
            for ch in range(lf_in.shape[1]):
                x_lf = lf_in[i,ch,:]
                lf_blocks = view_as_blocks(x_lf.cpu().numpy(), block_shape=self.block_shape)
                
                # Do Conv3d on each block 
                for x in range(len(lf_blocks)):
                    for y in range(len(lf_blocks)):
                        for z in range(len(lf_blocks)):
                            
                            conv_idx = x*len(lf_blocks) + y*len(lf_blocks) + z
                            
                            # Convolve the block, then multiply with the weight of the block. 
                            tensor_img = torch.from_numpy(lf_blocks[x,y,z])[None, None,:]
                            conv = self.conv1x1s[conv_idx](tensor_img.cuda())
                            
                            # w * x. 
                            # view_as_blocks returns a view so modifications are done in-place

                            lf_blocks[x,y,z] = tensor_img.cpu()*self.conv1x1s[conv_idx].weight.data.cpu()
        
        # Linearly sum the modalities together
        # out = w0*x0 + w1*x1 + w2*x2 + w3*x3 
        out = (lf_in[:,0]+lf_in[:,1]+lf_in[:,2]+lf_in[:,3])[:,None]
        
        return out

Any help is appreciated. Thank you!