How to split a (C x H x W) tensor into tiles?

I’d like to be able to split an image tensor (C x H x W) into tiles, do something with the tiles, and then put the tiles back together to recreate the original tensor.

How would I go about this?

I’ve put together a solution that seems to work the way I want it to. One function divides the tensor into tiles, and another puts all the tiles back together:

import torch
import math

def roll_tensor(tensor, h_shift=None, w_shift=None):
    """Cyclically shift a (C x H x W) tensor along its height and width.

    Args:
        tensor: Tensor to roll. Only 3-dimensional tensors are rolled; a
            tensor of any other rank is returned unchanged (the shifts are
            still returned).
        h_shift: Shift along dim 1 (height). If None, an integer is drawn
            uniformly from [-H, H - 1].
        w_shift: Shift along dim 2 (width). If None, an integer is drawn
            uniformly from [-W, W - 1].

    Returns:
        A tuple ``(rolled_tensor, h_shift, w_shift)``. Calling this function
        again with the negated shifts undoes the roll.
    """
    if h_shift is None:
        # torch.randint's upper bound is exclusive, so this yields [-H, H - 1].
        h_shift = torch.randint(-tensor.size(1), tensor.size(1), (1,)).item()
    if w_shift is None:
        w_shift = torch.randint(-tensor.size(2), tensor.size(2), (1,)).item()
    if tensor.dim() == 3:
        # One call rolls both axes; equivalent to rolling dim 1, then dim 2.
        tensor = torch.roll(tensor, shifts=(h_shift, w_shift), dims=(1, 2))
    return tensor, h_shift, w_shift

def split_tensor(tensor, tile_size=256, offset=256):
    """Split a (C x H x W) tensor into a row-major list of tiles.

    Tiles start every ``offset`` pixels along height and width and span up
    to ``tile_size`` pixels; tiles on the bottom/right edge are clipped to
    the tensor bounds and may therefore be smaller.

    Args:
        tensor: 3-dimensional tensor laid out as (C x H x W).
        tile_size: Maximum height and width of each tile.
        offset: Step between the top-left corners of neighbouring tiles.

    Returns:
        A tuple ``(tiles, base_tensor)``. ``tiles`` is a list of slices of
        ``tensor`` ordered row by row. ``base_tensor`` is a zero-filled
        tensor with the same shape, dtype and device as ``tensor``, meant to
        be passed to ``rebuild_tensor`` as the canvas.
    """
    h, w = tensor.size(1), tensor.size(2)
    rows = int(math.ceil(h / offset))
    cols = int(math.ceil(w / offset))
    tiles = [
        tensor[
            :,
            offset * y:min(offset * y + tile_size, h),
            offset * x:min(offset * x + tile_size, w),
        ]
        for y in range(rows)
        for x in range(cols)
    ]
    # zeros_like keeps dtype and device, so non-float32 inputs are not
    # silently cast when the tiles are written back.
    base_tensor = torch.zeros_like(tensor)
    return tiles, base_tensor
	
def rebuild_tensor(tensor_list, base_tensor, tile_size=256, offset=256):
    """Write a row-major list of tiles back into ``base_tensor``.

    The tile grid is derived from the height and width of ``base_tensor``
    together with ``tile_size`` and ``offset``, so these must match the
    values used to produce ``tensor_list``.

    Args:
        tensor_list: Tiles ordered row by row.
        base_tensor: (C x H x W) tensor that receives the tiles. It is
            modified in place.
        tile_size: Maximum height and width of each tile.
        offset: Step between the top-left corners of neighbouring tiles.

    Returns:
        ``base_tensor`` itself, after every tile has been copied in.
    """
    height = base_tensor.size(1)
    width = base_tensor.size(2)
    n_rows = int(math.ceil(height / offset))
    n_cols = int(math.ceil(width / offset))
    for row in range(n_rows):
        top = offset * row
        bottom = min(top + tile_size, height)
        for col in range(n_cols):
            left = offset * col
            right = min(left + tile_size, width)
            # Flat position of this tile in the row-major list.
            base_tensor[:, top:bottom, left:right] = tensor_list[row * n_cols + col]
    return base_tensor


# Example: round-trip a 3 x 512 x 405 image tensor through the tiling helpers.
# With the default 256-pixel tiles this produces a 2 x 2 grid whose right-hand
# column is only 149 pixels wide.
image_tensor = torch.randn(3, 512, 405)
rolled_tensor, h_shift, w_shift = roll_tensor(image_tensor) # randomly shift the tensor (intended to keep visible tile borders from forming)
tiles, base = split_tensor(rolled_tensor) # split the tensor into tiles, plus a zero-filled canvas of the same shape
# Do stuff to tiles here
rebuilt_tensor = rebuild_tensor(tiles, base) # write the tiles back into the canvas
unrolled_tensor, _, _ = roll_tensor(rebuilt_tensor, -h_shift, -w_shift) # undo the random shifts applied above

Can this code be improved?

1 Like

How about not randomizing, and just keeping the tiles in order?

@nile649 I’m not exactly sure what you mean, but if you want to avoid the randomization, just remove the two calls to the tensor rolling function:

rolled_tensor, h_shift, w_shift = roll_tensor(image_tensor) # randomly shift tensor to prevent tile borders from forming
unrolled_tensor, _, _ = roll_tensor(rebuilt_tensor, -h_shift, -w_shift) # Undo the above random shifts