Fold overlapping patches back into a 3D tensor?

Problem:
Fold/unfold overlapping patches from 3D tensors.

Context:
I am working on a segmentation problem where I make a prediction for each patch, and then use mean pooling to combine predictions from overlapping patches.

Using code from here, non-overlapping patches work great, but I have been unable to adapt the code to overlapping patches.

```python
import torch

# Unfold data: treat dims 1-3 of x as a 3D volume and cut it into 64^3 patches
x = torch.randn(1, 256, 256, 256)  # batch, c, h, w
kc, kh, kw = 64, 64, 64  # kernel size
dc, dh, dw = 64, 64, 64  # stride (equal to the kernel size, so patches do not overlap)
patches = x.unfold(1, kc, dc).unfold(2, kh, dh).unfold(3, kw, dw)
unfold_shape = patches.size()  # (1, 4, 4, 4, 64, 64, 64)
patches = patches.contiguous().view(patches.size(0), -1, kc, kh, kw)  # (1, 64, 64, 64, 64)

# Reshape back
patches_orig = patches.view(unfold_shape)
output_c = unfold_shape[1] * unfold_shape[4]
output_h = unfold_shape[2] * unfold_shape[5]
output_w = unfold_shape[3] * unfold_shape[6]
# Interleave each patch-grid dim with its in-patch dim: (B, nC, kc, nH, kh, nW, kw)
patches_orig = patches_orig.permute(0, 1, 4, 2, 5, 3, 6).contiguous()
patches_orig = patches_orig.view(unfold_shape[0], output_c, output_h, output_w)

# Check for equality
print((patches_orig == x).all())
```

Is it possible to use stride=16 and combine the overlapping patches with mean pooling?

Cheers,
Brendan
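One possible approach in plain PyTorch, sketched here under stated assumptions rather than taken from any library: unfold a tensor of flat voxel indices with the same kernel and stride as the data, so every patch element knows where it came from, then scatter-add the (possibly predicted) patch values with `Tensor.index_add_` and divide by how many patches covered each voxel. The helpers `unfold3d` and `fold_mean` are hypothetical names. The demo uses a 64³ volume with kernel 32 and stride 16, because kernel 64 with stride 16 on the 256³ input above gives 13³ = 2197 patches of 64³ voxels, roughly 576 million elements (about 2.3 GB in float32) once materialized.

```python
import torch

def unfold3d(x, k, s):
    """Cut dims 1-3 of a (B, C, H, W) tensor into overlapping k**3 patches with stride s.
    Result shape: (B, nC, nH, nW, k, k, k)."""
    return x.unfold(1, k, s).unfold(2, k, s).unfold(3, k, s)

def fold_mean(patches, shape, k, s):
    """Hypothetical inverse of unfold3d that averages overlapping voxels.

    patches: tensor shaped like unfold3d's output (e.g. per-patch predictions).
    shape:   shape of the original (B, C, H, W) tensor.
    """
    n = torch.Size(shape).numel()
    # Flat index of every voxel, unfolded exactly like the data.
    idx = torch.arange(n, device=patches.device).view(shape)
    idx_flat = unfold3d(idx, k, s).reshape(-1)

    out = torch.zeros(n, dtype=patches.dtype, device=patches.device)
    count = torch.zeros_like(out)
    # Sum every patch value into its source voxel and count the contributions.
    out.index_add_(0, idx_flat, patches.reshape(-1))
    count.index_add_(0, idx_flat, torch.ones_like(patches).reshape(-1))
    # Voxels covered by no patch stay 0 (clamp avoids dividing by zero).
    return (out / count.clamp(min=1)).view(shape)

# Demo: overlapping patches (kernel 32, stride 16) folded back by averaging.
x = torch.randn(1, 64, 64, 64)
patches = unfold3d(x, 32, 16)            # (1, 3, 3, 3, 32, 32, 32)
x_rec = fold_mean(patches, x.shape, 32, 16)
print(torch.allclose(x_rec, x))          # True
```

Voxels that no patch reaches (when size - kernel is not a multiple of the stride) come back as 0; padding the input, or adding a final window aligned to the border, avoids that.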

The best option I found was MONAI's sliding window inference function, `monai.inferers.sliding_window_inference`. It was exactly what I was looking for.
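For reference, a call along the lines of `sliding_window_inference(inputs, roi_size=(64, 64, 64), sw_batch_size=4, predictor=model, overlap=0.75, mode="constant")` expects a (batch, channel, *spatial) input, so the tensor above would need a channel dimension (e.g. `x.unsqueeze(1)`); `model` and `sw_batch_size=4` are illustrative. With a 64-voxel window, an overlap of 0.75 corresponds to a stride of 16, and `"constant"` mode weights every window equally, i.e. mean pooling. The sketch below is not MONAI's code. It is a minimal pure-PyTorch illustration of the same idea, assuming a hypothetical `predictor` that returns an output with the same spatial size as its input window, and a volume at least as large as the window in every dimension.

```python
import torch

def window_starts(size, k, s):
    """Start offsets of length-k windows with stride s; the last window is
    shifted to end exactly at `size` so every voxel is covered."""
    starts = list(range(0, size - k + 1, s))
    if starts[-1] != size - k:
        starts.append(size - k)
    return starts

@torch.no_grad()  # inference only
def sliding_window_mean(x, predictor, k, s):
    """Run `predictor` on overlapping k**3 windows of a (B, C, D, H, W) volume
    and average the overlapping outputs (equal weight per window)."""
    B, _, D, H, W = x.shape
    out = None
    count = torch.zeros(1, 1, D, H, W, dtype=x.dtype, device=x.device)
    for d in window_starts(D, k, s):
        for h in window_starts(H, k, s):
            for w in window_starts(W, k, s):
                pred = predictor(x[:, :, d:d + k, h:h + k, w:w + k])
                if out is None:
                    # Output channels may differ from input channels (e.g. class logits).
                    out = torch.zeros(B, pred.shape[1], D, H, W,
                                      dtype=pred.dtype, device=x.device)
                out[:, :, d:d + k, h:h + k, w:w + k] += pred
                count[:, :, d:d + k, h:h + k, w:w + k] += 1
    return out / count

# Demo with an identity "model": averaging identical copies reproduces the input.
x = torch.randn(1, 1, 80, 80, 80)
y = sliding_window_mean(x, lambda p: p, k=64, s=16)
print(torch.allclose(y, x))  # True
```

This loops over one window at a time for clarity; batching several windows per `predictor` call (what `sw_batch_size` controls in MONAI) is the obvious speed-up.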