Would it be possible to post the code that you ran successfully on a test image?
I have this function working correctly:
def split_tensor(tensor, tile_size=256, stride=None):
    """Split a batch of images into (possibly overlapping) square tiles.

    Patches are extracted with ``nn.Unfold``. An all-ones tensor shaped
    like the input is unfolded with the same settings and returned as
    well. Presumably the caller uses it with ``nn.Fold`` to count tile
    overlaps when stitching; that use is not shown here.

    Args:
        tensor: Input image batch of shape ``(N, C, H, W)``
            (``nn.Unfold`` only accepts 4-D input).
        tile_size: Side length of each square tile.
        stride: Step between neighbouring tiles. Defaults to
            ``tile_size // 2`` (50% overlap), the original behaviour.

    Returns:
        ``(tiles, mask_p)`` where:

        - ``tiles`` is a list of ``N * L`` tensors, each of shape
          ``(1, C, tile_size, tile_size)``. They are ordered batch-major
          (every tile of image 0, then image 1, ...). ``L`` is the number
          of tile positions per image.
        - ``mask_p`` is the unfolded all-ones mask, of shape
          ``(N, C * tile_size**2, L)``.
    """
    if stride is None:
        stride = tile_size // 2
    unfold = nn.Unfold(kernel_size=(tile_size, tile_size), stride=stride)

    # Unfold the mask with the same settings as the image.
    mask_p = unfold(torch.ones_like(tensor))

    # Unfold yields (N, C*k*k, L); the C*k*k axis is ordered (channel, row, col).
    # Take N and C from the input instead of assuming a single 3-channel
    # image. Otherwise other channel counts break, and batches get mixed.
    batch, channels = tensor.shape[:2]
    patches = unfold(tensor)
    patches = (
        patches.reshape(batch, channels, tile_size, tile_size, -1)
        .permute(0, 4, 1, 2, 3)  # -> (N, L, C, k, k)
        .reshape(-1, channels, tile_size, tile_size)  # -> (N*L, C, k, k)
    )

    # Index with a list so each tile keeps a leading dim of size 1.
    tiles = [patches[[t]] for t in range(patches.size(0))]
    return tiles, mask_p