How to split tensors with overlap and average overlapping values?

Hi. I am working on a medical image segmentation project, and I have a model that takes inputs of size 128x32x128 (X, Z, Y). However, my inputs are not all evenly divisible by these values, and not all inputs are the same size. For example, some inputs are divisible (512x96x512) and many are not (480x68x480). How can I split these inputs into overlapping patches to feed into the network, and then reconstruct the outputs into a single tensor by averaging the overlapping output values, so that the result looks as if I had fed the entire image into the network? The output has shape (B, C, X, Z, Y).

For clarification, the inputs are grayscale, so they have 1 channel and are of shape (1, W, D, H). The network takes in batches of these inputs and outputs a tensor of shape (B, 10, W, D, H), where 10 is the number of classes. I have been looking into nn.Fold and nn.Unfold, but I noticed that they only work on 4D tensors. I could theoretically unfold the input volumes because they are 4D, but I cannot fold them back together because the output is a 5D tensor.
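A minimal pure-PyTorch sketch of the split, run, and stitch approach described above. It assumes the network accepts (B, 1, 128, 32, 128) and returns (B, 10, 128, 32, 128), that every axis of the volume is at least as large as the patch (smaller volumes would need padding first), and a stride of (96, 24, 96), i.e. 25% overlap, which is an arbitrary choice. `patch_starts` and `sliding_window_average` are hypothetical helper names, not a library API. Instead of nn.Fold, it slices patches directly, keeps a running sum of logits plus a per-voxel count of covering patches, and divides at the end. The last window along each axis is shifted back to end at the volume edge, which is how sizes such as 68 or 480 that are not multiples of the stride still get covered.

```python
import torch
import torch.nn as nn


def patch_starts(size: int, patch: int, stride: int) -> list[int]:
    """Start offsets along one axis so that windows of length `patch` cover [0, size).

    The last window is shifted back to end exactly at `size`, so it may overlap its
    neighbour by more than `patch - stride` voxels. Assumes size >= patch.
    """
    if size < patch:
        raise ValueError(f"axis of size {size} is smaller than the patch ({patch}); pad it first")
    starts = list(range(0, size - patch + 1, stride))
    if starts[-1] + patch < size:
        starts.append(size - patch)  # extra window flush with the far edge
    return starts


@torch.no_grad()
def sliding_window_average(model, volume, patch=(128, 32, 128), stride=(96, 24, 96),
                           num_classes=10, batch_size=4):
    """Run `model` on overlapping 3D patches of `volume` (1, 1, X, Z, Y) and average
    the overlapping logits into a (1, num_classes, X, Z, Y) tensor.

    Assumption: model maps (B, 1, px, pz, py) -> (B, num_classes, px, pz, py).
    """
    _, _, X, Z, Y = volume.shape
    px, pz, py = patch
    summed = volume.new_zeros((1, num_classes, X, Z, Y))  # running sum of logits
    counts = volume.new_zeros((1, 1, X, Z, Y))            # patches covering each voxel
    corners = [(x, z, y)
               for x in patch_starts(X, px, stride[0])
               for z in patch_starts(Z, pz, stride[1])
               for y in patch_starts(Y, py, stride[2])]
    for i in range(0, len(corners), batch_size):
        chunk = corners[i:i + batch_size]
        # Stack this chunk of patches into one batch of shape (len(chunk), 1, px, pz, py).
        batch = torch.cat([volume[..., x:x + px, z:z + pz, y:y + py] for x, z, y in chunk])
        logits = model(batch)
        for (x, z, y), pred in zip(chunk, logits):
            summed[..., x:x + px, z:z + pz, y:y + py] += pred
            counts[..., x:x + px, z:z + pz, y:y + py] += 1
    # Every voxel is covered at least once, so the division is safe.
    return summed / counts


if __name__ == "__main__":
    # Stand-in model: a 1x1x1 conv has no spatial context, so patch-wise inference
    # should match whole-volume inference (up to float rounding).
    model = nn.Conv3d(1, 10, kernel_size=1).eval()
    volume = torch.randn(1, 1, 40, 20, 36)  # deliberately not a multiple of the patch
    out = sliding_window_average(model, volume, patch=(16, 8, 16), stride=(12, 6, 12))
    print(out.shape)  # torch.Size([1, 10, 40, 20, 36])
    with torch.no_grad():
        print(torch.allclose(out, model(volume), atol=1e-5))  # True
```

For a 480x68x480 volume with the default sizes, `patch_starts` gives [0, 96, 192, 288, 352] along X and Y and [0, 24, 36] along Z, so 5 x 3 x 5 = 75 patches are run. Uniform averaging is the simplest choice; weighting each patch toward its center is a common alternative if seams show up at patch borders.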

Maybe consider reshaping the tensors after unfolding. I did something like this once:

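```python
import torch
import torch.nn as nn

x = torch.randn(5, 5, 5, 5)  # (N, C, H, W)
a = nn.Fold(output_size=(5, 5), kernel_size=(3, 3))
# Tensor.unfold gives (N, C, L_h, L_w, k_h, k_w); nn.Fold expects (N, C*k_h*k_w, L_h*L_w),
# so move the kernel dims ahead of the location dims before reshaping.
d = x.unfold(2, 3, 1).unfold(3, 3, 1).permute(0, 1, 4, 5, 2, 3).reshape(x.size(0), -1, 9)
print(a(d).shape)  # torch.Size([5, 5, 5, 5])
```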
This gets back the same shape after unfolding and then folding. Note that PyTorch's version of fold adds up overlapping values rather than averaging them.
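Since nn.Fold sums overlapping values, one way to turn that sum into an average is to push a tensor of ones through the same unfold/fold pair, which counts how many patches cover each pixel, and divide by it. A minimal 2D sketch (the sizes are chosen only for illustration); with an identity "model" between unfold and fold, the average recovers the input. nn.Fold and nn.Unfold only handle 2D spatial inputs, so for 3D volumes the same sum-then-divide-by-count idea has to be done with manual slicing instead.

```python
import torch
import torch.nn as nn

kernel_size, stride = (3, 3), 1
unfold = nn.Unfold(kernel_size=kernel_size, stride=stride)
fold = nn.Fold(output_size=(5, 5), kernel_size=kernel_size, stride=stride)

x = torch.randn(2, 4, 5, 5)                # (N, C, H, W)
patches = unfold(x)                        # (N, C*3*3, L) = (2, 36, 9)
# A per-patch model would transform `patches` here; identity for the demo.
summed = fold(patches)                     # overlapping entries are added up
counts = fold(unfold(torch.ones_like(x)))  # number of patches covering each pixel
averaged = summed / counts                 # assumes every pixel is covered by some patch

print(patches.shape)                       # torch.Size([2, 36, 9])
print(torch.allclose(averaged, x))         # True
```

Here `counts` is 1 at the corners and 9 at the center pixel. If the stride and kernel leave some pixels uncovered, `counts` is 0 there and the division yields NaN, so those layouts need padding or a different stride.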

Hey, @whuang7000. You can probably use the patch-based pipelines in TorchIO for this.