Yes, you have to reshape the patches back into the layout fold expects with
.permute(1, 2, 3, 0).reshape(1, 3*p_size*p_size, patches_shape[-1])
to get something like torch.Size([1, 49152, 225]), i.e. (N, C*p_size*p_size, L): 3*128*128 = 49152 values per patch and 225 patches,
which is what the fold function requires. To get the same behavior as before, you also have to set the stride to stride = tile_size.
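Here is a minimal sketch of the whole round trip. The original patch-extraction code isn't shown here, so this assumes the patches were cut with two Tensor.unfold calls and arranged as (L, C, p_size, p_size), which is consistent with the .permute(1, 2, 3, 0) above. make_patches is a hypothetical helper. The 1920x1920 image and p_size = 128 are assumptions chosen to reproduce the [1, 49152, 225] shape.

```python
import torch
import torch.nn.functional as F


def make_patches(img: torch.Tensor, p_size: int, stride: int) -> torch.Tensor:
    """Hypothetical helper: cut a (C, H, W) image into (L, C, p_size, p_size) tiles."""
    c = img.shape[0]
    # Tensor.unfold(dim, size, step) slides a window over H, then over W
    # -> (C, n_h, n_w, p_size, p_size)
    patches = img.unfold(1, p_size, stride).unfold(2, p_size, stride)
    # flatten the tile grid row by row -> (C, L, p, p), then put L first
    return patches.contiguous().view(c, -1, p_size, p_size).permute(1, 0, 2, 3)


p_size = 128                     # assumed tile size (3 * 128 * 128 = 49152)
img = torch.rand(3, 1920, 1920)  # random image: 15 x 15 = 225 tiles
patches = make_patches(img, p_size, stride=p_size)  # (225, 3, 128, 128)

# (L, C, p, p) -> (C, p, p, L) -> (1, C*p*p, L), the layout F.fold expects
cols = patches.permute(1, 2, 3, 0).reshape(1, 3 * p_size * p_size, patches.shape[0])
print(cols.shape)  # torch.Size([1, 49152, 225])

# stride == tile size: tiles don't overlap, so fold just puts each one back in place
rebuilt = F.fold(cols, output_size=tuple(img.shape[-2:]), kernel_size=p_size, stride=p_size)
print(torch.equal(rebuilt[0], img))  # True
```

The permute puts the channel and pixel dimensions in front so the reshape yields the channel-major (C, kh, kw) ordering that F.fold reads along dim 1.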
And I literally used the exact same code you posted above, just loading a random image and rebuilding it.
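To see why the stride matters, note that fold sums every value that lands on the same output pixel. With overlapping tiles (stride < tile size) the folded image is therefore scaled by how many tiles cover each pixel. The sketch below divides by a coverage count to recover the image. It uses F.unfold instead of the Tensor.unfold chain for brevity, and the sizes (64x64 image, 16-pixel tiles, stride 8) are arbitrary assumptions.

```python
import torch
import torch.nn.functional as F

C, H, W = 3, 64, 64
p_size, stride = 16, 8  # assumption: overlapping tiles (stride < p_size)
img = torch.rand(1, C, H, W)

# F.unfold already returns (N, C*p*p, L), the layout F.fold expects
cols = F.unfold(img, kernel_size=p_size, stride=stride)

# fold adds up overlapping contributions, so pixels covered by k tiles are scaled by k
summed = F.fold(cols, output_size=(H, W), kernel_size=p_size, stride=stride)

# coverage count per pixel: unfold + fold an all-ones image with the same settings
ones_cols = F.unfold(torch.ones_like(img), kernel_size=p_size, stride=stride)
counts = F.fold(ones_cols, output_size=(H, W), kernel_size=p_size, stride=stride)

rebuilt = summed / counts
print(torch.allclose(rebuilt, img))  # True
```

With stride = tile_size every count is 1, which is why no normalization is needed in that case.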