# How to split a (C x H x W) tensor into tiles?

I’d like to split an image tensor of shape (C x H x W) into tiles, do something with the tiles, and then put the tiles back together to recreate the original tensor.

I’ve put together a solution that seems to work the way I want. One function divides the tensor into tiles, and the other puts all the tiles back together:

```python
import torch
import math

def roll_tensor(tensor, h_shift=None, w_shift=None):
    """Circularly shift a (C, H, W) tensor along its height and width axes.

    Args:
        tensor: Tensor indexed as (C, H, W). Only 3-D tensors are actually
            rolled (along dims 1 and 2); a tensor of any other rank is returned
            unchanged. Drawing a random shift still reads sizes of dims 1 and 2,
            so those dims must exist when a shift is left as None.
        h_shift: Shift along dim 1 (height). If None, a random integer in
            [-H, H) is drawn.
        w_shift: Shift along dim 2 (width). If None, a random integer in
            [-W, W) is drawn.

    Returns:
        (rolled_tensor, h_shift, w_shift). Calling
        ``roll_tensor(rolled_tensor, -h_shift, -w_shift)`` undoes the roll.
    """
    if h_shift is None:
        # Draw a single scalar; the upper bound of randint is exclusive.
        # (Calling .item() on a multi-element tensor, as LongTensor(10) gives,
        # raises instead of returning a number.)
        h_shift = int(torch.randint(-tensor.size(1), tensor.size(1), (1,)).item())
    if w_shift is None:
        w_shift = int(torch.randint(-tensor.size(2), tensor.size(2), (1,)).item())
    if tensor.dim() == 3:
        # One roll over both spatial dims is equivalent to two nested rolls.
        tensor = torch.roll(tensor, shifts=(h_shift, w_shift), dims=(1, 2))
    return tensor, h_shift, w_shift

def split_tensor(tensor, tile_size=256, offset=256):
    """Cut a (C, H, W) tensor into a row-major list of spatial tiles.

    A tile origin is placed every ``offset`` pixels along H and W, and each
    tile extends up to ``tile_size`` pixels, clipped at the tensor edge, so
    tiles in the last row/column may be smaller. ``tile_size > offset`` yields
    overlapping tiles; ``tile_size < offset`` leaves gaps covered by no tile.

    Args:
        tensor: Tensor indexed as (C, H, W).
        tile_size: Maximum tile height and width.
        offset: Stride between consecutive tile origins; must be positive.

    Returns:
        (tiles, base_tensor): ``tiles`` are slicing views into ``tensor`` (they
        share its storage), ordered row by row. ``base_tensor`` is a zero
        tensor with the same shape and device as ``tensor``, intended to be
        passed to ``rebuild_tensor``.

    Raises:
        ValueError: if ``offset`` is not positive.
    """
    if offset <= 0:
        raise ValueError(f"offset must be positive, got {offset}")
    h, w = tensor.size(1), tensor.size(2)
    tiles = [
        tensor[:, top:min(top + tile_size, h), left:min(left + tile_size, w)]
        for top in range(0, h, offset)
        for left in range(0, w, offset)
    ]
    # Keep the input's device (not just CUDA) and never lose precision:
    # promoting against the default dtype keeps float32 for float32/half/uint8
    # inputs, as before, but preserves float64/complex inputs instead of
    # silently downcasting them.
    base_dtype = torch.promote_types(tensor.dtype, torch.get_default_dtype())
    base_tensor = torch.zeros(tensor.size(), dtype=base_dtype, device=tensor.device)
    return tiles, base_tensor

def rebuild_tensor(tensor_list, base_tensor, tile_size=256, offset=256):
    """Paste tiles back into ``base_tensor`` at the positions split_tensor uses.

    Tile origins are visited row by row, every ``offset`` pixels along dims 1
    and 2 of ``base_tensor``. The i-th origin receives ``tensor_list[i]``,
    written into a window up to ``tile_size`` wide and clipped at the edge.
    Where windows overlap, later tiles overwrite earlier ones. Regions covered
    by no window keep whatever ``base_tensor`` already held.

    Args:
        tensor_list: Tiles in row-major order; each must fit its window.
        base_tensor: Destination tensor indexed as (C, H, W); modified in place.
        tile_size: Maximum tile height and width used when splitting.
        offset: Stride between tile origins used when splitting.

    Returns:
        ``base_tensor`` itself, after the tiles have been written into it.

    Raises:
        IndexError: if ``tensor_list`` holds fewer tiles than there are origins.
    """
    height = base_tensor.size(1)
    width = base_tensor.size(2)
    n_rows = int(math.ceil(height / offset))
    n_cols = int(math.ceil(width / offset))
    # Row-major origins, matching the order in which the tiles were produced.
    origins = ((r * offset, c * offset) for r in range(n_rows) for c in range(n_cols))
    for index, (top, left) in enumerate(origins):
        bottom = min(top + tile_size, height)
        right = min(left + tile_size, width)
        base_tensor[:, top:bottom, left:right] = tensor_list[index]
    return base_tensor

# Example: split a 3 x 512 x 405 image into tiles of the default size (256),
# process them, and reassemble. With offset 256 this gives 2 tile rows
# (512 / 256) and 2 tile columns (ceil(405 / 256)); the right-hand column is
# only 149 pixels wide because tiles are clipped at the edge.
image_tensor = torch.randn(3, 512, 405)
# Circularly shift by a random amount so the tile seams fall at a different
# place on each call (the author's goal: avoid visible tile borders across
# repeated runs). NOTE(review): within a single pass the seams still exist;
# they are only moved.
rolled_tensor, h_shift, w_shift = roll_tensor(image_tensor)
# tiles are slicing views into rolled_tensor; base is a zero tensor of the same shape.
tiles, base = split_tensor(rolled_tensor)
# Do stuff to tiles here (keep each tile's shape so it fits back into its slot)
# rebuild_tensor writes into `base` in place and returns that same tensor.
rebuilt_tensor = rebuild_tensor(tiles, base)
# Roll by the negated shifts to restore the original alignment.
unrolled_tensor, _, _ = roll_tensor(rebuilt_tensor, -h_shift, -w_shift)
``````

Can this code be improved?

1 Like

How about not randomizing, and just keeping the tiles in order?

@nile649 I’m not exactly sure what you mean, but if you want to avoid the randomization, just remove these two tensor-rolling calls:

```python
rolled_tensor, h_shift, w_shift = roll_tensor(image_tensor) # randomly shift tensor to prevent tile borders from forming
unrolled_tensor, _, _ = roll_tensor(rebuilt_tensor, -h_shift, -w_shift) # Undo the above random shifts
``````