Fastest way to shuffle 3D tensor

Hi, I'm new to PyTorch, so this might be very basic.
I have a batch of volumes with shape BxCxWxHxD, where W = H = D = 160, B is the batch size and C is the number of channels.
What I'm trying to do is divide each volume into smaller blocks, shuffle the blocks, and rebuild the volume from them. Depending on the block size, doing this with indexing can be quite slow, so I've been looking for a faster, native approach.
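For reference, here is a minimal sketch of what an index-based version of this could look like. It is a hypothetical baseline (not necessarily the approach I used), assuming non-overlapping cubic blocks whose size divides W, H and D, with one block permutation shared by the whole batch. The function name `shuffle_blocks_loop` and its `perm` parameter are my own.

```python
import torch


def shuffle_blocks_loop(x, size, perm=None):
    """Hypothetical index-based baseline: one Python iteration per block.

    Assumes x has shape (B, C, W, H, D) with W, H and D divisible by `size`.
    The same block permutation is applied to every volume and channel.
    """
    _, _, W, H, D = x.shape
    # Starting corner of every block, enumerated in row-major (w, h, d) order.
    corners = [(i * size, j * size, k * size)
               for i in range(W // size)
               for j in range(H // size)
               for k in range(D // size)]
    if perm is None:
        perm = torch.randperm(len(corners))
    out = torch.empty_like(x)
    for dst, src in enumerate(perm.tolist()):
        dw, dh, dd = corners[dst]
        sw, sh, sd = corners[src]
        # Input block `src` is copied to output position `dst`.
        block = x[..., sw:sw + size, sh:sh + size, sd:sd + size]
        out[..., dw:dw + size, dh:dh + size, dd:dd + size] = block
    return out


x = torch.randn(2, 1, 160, 160, 160)
y = shuffle_blocks_loop(x, 80)
print(y.shape)  # torch.Size([2, 1, 160, 160, 160])
```

With size=80 the loop runs only 8 times, but with size=4 it would run 40 × 40 × 40 = 64,000 times, which is where a Python-level loop starts to hurt.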
I found the thread "How to extract smaller image patches (3D)?" and managed to use it to divide the volumes very efficiently. This is my modification of @ptrblck's suggestion:

import torch

W = 160 # width
H = 160 # height
D = 160 # depth
batch_size = 2
channel_size = 1

x = torch.randn(batch_size, channel_size, W, H, D)
print(x.shape) # torch.Size([2, 1, 160, 160, 160])
size = 80 # patch size
stride = 80 # patch stride (equal to size, so patches don't overlap)
patches = x.unfold(2, size, stride).unfold(3, size, stride).unfold(4, size, stride)
patches = patches.reshape(batch_size, channel_size, -1, size, size, size)
print(patches.shape) # torch.Size([2, 1, 8, 80, 80, 80])
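To see what the three `unfold` calls produce before the reshape, and how the flattened patch index maps back to a grid position, here is a small sanity check using the same sizes as above. The variable names (`u`, `nW`, `nH`, `nD`, `expected`) are my own.

```python
import torch

B, C, W, H, D = 2, 1, 160, 160, 160
size = stride = 80  # stride == size, so patches tile the volume without overlap

x = torch.randn(B, C, W, H, D)
u = x.unfold(2, size, stride).unfold(3, size, stride).unfold(4, size, stride)
# Layout: (B, C, nW, nH, nD, size, size, size), i.e. grid position first,
# then the offset inside the patch along W, H and D.
print(u.shape)  # torch.Size([2, 1, 2, 2, 2, 80, 80, 80])

patches = u.reshape(B, C, -1, size, size, size)
nW, nH, nD = u.shape[2:5]

# The flattened patch index n walks the grid in row-major order:
# n = (i * nH + j) * nD + k for grid position (i, j, k).
i, j, k = 1, 0, 1
n = (i * nH + j) * nD + k
expected = x[:, :, i * size:(i + 1) * size, j * size:(j + 1) * size, k * size:(k + 1) * size]
print(n, torch.equal(patches[:, :, n], expected))  # 5 True
```

Keeping this layout in mind is what makes the inverse possible: the volume can be rebuilt by undoing exactly this grouping of axes.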

For the shuffle, I can just shuffle along dim 2 (the patch dimension). However, now I don't know how to reconstruct the volume in its original shape.
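`torch.nn.functional.fold` only supports 2D spatial outputs, so it can't directly invert a 3D unfold. One possible way to rebuild the volume, assuming non-overlapping patches (stride == size) and spatial dims divisible by the patch size, is to undo the reshape by hand: split the patch index back into the (nW, nH, nD) grid, permute so each grid axis sits next to its within-patch axis, then merge each pair. `patchify` and `unpatchify` are hypothetical helper names, not PyTorch functions.

```python
import torch


def patchify(x, size):
    """(B, C, W, H, D) -> (B, C, N, size, size, size) plus the (nW, nH, nD) grid."""
    B, C = x.shape[:2]
    u = x.unfold(2, size, size).unfold(3, size, size).unfold(4, size, size)
    grid = tuple(u.shape[2:5])
    return u.reshape(B, C, -1, size, size, size), grid


def unpatchify(patches, grid):
    """Inverse of patchify, assuming non-overlapping patches (stride == size)."""
    B, C, _, s, _, _ = patches.shape
    nW, nH, nD = grid
    # Split the flat patch index N back into grid coordinates (nW, nH, nD).
    p = patches.reshape(B, C, nW, nH, nD, s, s, s)
    # Put each grid axis next to its within-patch axis: (nW, s, nH, s, nD, s).
    p = p.permute(0, 1, 2, 5, 3, 6, 4, 7)
    # Merge each (grid, offset) pair into a single spatial axis.
    return p.reshape(B, C, nW * s, nH * s, nD * s)


x = torch.randn(2, 1, 160, 160, 160)
patches, grid = patchify(x, 80)
print(patches.shape, grid)  # torch.Size([2, 1, 8, 80, 80, 80]) (2, 2, 2)

# Without shuffling, the round trip reproduces the input exactly.
print(torch.equal(unpatchify(patches, grid), x))  # True

# Shuffle along dim 2 (the patch dimension), then rebuild the volume.
perm = torch.randperm(patches.shape[2])
shuffled = unpatchify(patches[:, :, perm], grid)
print(shuffled.shape)  # torch.Size([2, 1, 160, 160, 160])
```

This applies the same permutation to every volume in the batch; if each volume should be shuffled differently, one `torch.randperm` per batch element would be needed instead.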
Thanks in advance!