I’ve put together a solution that seems to do what I want. One function splits the tensor into tiles, another stitches the tiles back together, and a small helper randomly rolls the tensor so that tile seams don’t end up in the same place every time:
import torch
import math

def roll_tensor(tensor, h_shift=None, w_shift=None):
    # Pick random shifts if none were supplied; the shift range spans the
    # full height/width in both directions.
    if h_shift is None:
        h_shift = torch.randint(-tensor.size(1), tensor.size(1), (1,)).item()
    if w_shift is None:
        w_shift = torch.randint(-tensor.size(2), tensor.size(2), (1,)).item()
    if tensor.dim() == 3:
        # Cyclically shift a CHW tensor along its height and width dims.
        tensor = torch.roll(tensor, shifts=(h_shift, w_shift), dims=(1, 2))
    return tensor, h_shift, w_shift
def split_tensor(tensor, tile_size=256, offset=256):
    # Slice a CHW tensor into tiles; slices are clamped to the image
    # borders, so edge tiles can be smaller than tile_size.
    tiles = []
    h, w = tensor.size(1), tensor.size(2)
    for y in range(math.ceil(h / offset)):
        for x in range(math.ceil(w / offset)):
            tiles.append(tensor[:, offset * y:min(offset * y + tile_size, h),
                                offset * x:min(offset * x + tile_size, w)])
    # Empty canvas to paste the tiles back onto, created on the same
    # device and with the same dtype as the input.
    base_tensor = torch.zeros_like(tensor)
    return tiles, base_tensor
def rebuild_tensor(tensor_list, base_tensor, tile_size=256, offset=256):
    # Paste the tiles back onto the canvas in the same order that
    # split_tensor produced them.
    num_tiles = 0
    h, w = base_tensor.size(1), base_tensor.size(2)
    for y in range(math.ceil(h / offset)):
        for x in range(math.ceil(w / offset)):
            base_tensor[:, offset * y:min(offset * y + tile_size, h),
                        offset * x:min(offset * x + tile_size, w)] = tensor_list[num_tiles]
            num_tiles += 1
    return base_tensor
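
# A quick shape check (illustrative; demo_tiles is just a throwaway name):
# with the default tile_size and offset, a 3x512x405 input splits into a
# 2x2 grid, and the right-hand column of tiles is narrower than tile_size
# because the slices are clamped at the image edge.
demo_tiles, _ = split_tensor(torch.randn(3, 512, 405))
print([tuple(t.shape) for t in demo_tiles])  # [(3, 256, 256), (3, 256, 149), (3, 256, 256), (3, 256, 149)]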
# Example:
image_tensor = torch.randn(3, 512, 405)
rolled_tensor, h_shift, w_shift = roll_tensor(image_tensor) # randomly shift tensor to prevent tile borders from forming
tiles, base = split_tensor(rolled_tensor) # split tensor into tiles
# Do stuff to tiles here
rebuilt_tensor = rebuild_tensor(tiles, base) # put tiles back together
unrolled_tensor, _, _ = roll_tensor(rebuilt_tensor, -h_shift, -w_shift) # Undo the above random shifts
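Since torch.roll wraps values around rather than discarding them, the whole round trip should be lossless as long as the tiles themselves aren’t modified; a quick sanity check:

assert torch.equal(unrolled_tensor, image_tensor)  # exact round trip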
Is there any way this code could be improved?