Seamlessly blending tensors together?

I’m looking to seamlessly blend tensors together. The code below splits a tensor into overlapping tiles and then puts the tiles back together. If any of the tiles are modified even slightly, the boundaries between them become extremely obvious when they’re reassembled, so the overlapping regions can’t simply be added together either.

import torch
from PIL import Image
import torchvision.transforms as transforms

def split_tensor(tensor, tile_size=256, overlap=0):
    tiles = []
    h, w = tensor.size(2), tensor.size(3)
    # -(h // -tile_size) is ceiling division: the number of tiles needed along each dimension
    for y in range(int(-(h // -tile_size))):
        for x in range(int(-(w // -tile_size))):
            y_val = max(min(tile_size*y+tile_size +overlap, h), 0)
            x_val = max(min(tile_size*x+tile_size +overlap, w), 0)
            ty = tile_size*y
            tx = tile_size*x

            # If a tile would be clipped by the image edge, shift it back so it
            # stays full-sized and overlaps the previous tile instead
            if abs(tx - x_val) < tile_size:
                tx = x_val-tile_size
            if abs(ty - y_val) < tile_size:
                ty = y_val-tile_size
               
            tiles.append(tensor[:, :, ty:y_val, tx:x_val])
    if tensor.is_cuda:
        base_tensor = torch.zeros(tensor.squeeze(0).size(), device=tensor.get_device())
    else: 
        base_tensor = torch.zeros(tensor.squeeze(0).size())
    return tiles, base_tensor.unsqueeze(0)  

def rebuild_tensor(tensor_list, base_tensor, tile_size=256, overlap=0):
    num_tiles = 0
    h, w = base_tensor.size(2), base_tensor.size(3)
    for y in range(int(-(h // -tile_size))):       
        for x in range(int(-(w // -tile_size))): 
            y_val = max(min(tile_size*y+tile_size +overlap, h), 0)
            x_val = max(min(tile_size*x+tile_size +overlap, w), 0)
            ty = tile_size*y
            tx = tile_size*x
        
            if abs(tx - x_val) < tile_size:
                tx = x_val-tile_size                
            if abs(ty - y_val) < tile_size:
                ty = y_val-tile_size
            
            base_tensor[:, :, ty:y_val, tx:x_val] = tensor_list[num_tiles] # Put tiles on base tensor made of zeros
            num_tiles+=1              
    return base_tensor

# Load image and preprocess it into a tensor
test_image = 'test_image.jpg'
image_size=1024
Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
input_tensor = Loader(Image.open(test_image).convert('RGB')).unsqueeze(0)

# Split image into overlapping tiles
tile_tensors, base_t = split_tensor(input_tensor, 512)


# Put tiles back together
output_tensor = rebuild_tensor(tile_tensors, base_t, 512)

# Save output to see result
Image2PIL = transforms.ToPILImage()
Image2PIL(output_tensor.cpu().squeeze(0)).save('output_image.png')

Hi,
I am not quite sure I understood your problem correctly, but it sounds like torch.nn.Unfold and torch.nn.Fold are what you’re looking for. If the problem is that overlapping tiles get added together, you can create a mask and divide by it to recover the original input.

For example:

import torch
import torch.nn as nn

img     = torch.ones(1, 3, 1024, 1024)
mask    = torch.ones_like(img)  # Create mask to correct for areas which are going to be added together
# use torch.nn.Unfold
p_size  = 128
stride  = p_size//2
unfold  = nn.Unfold(kernel_size=(p_size, p_size), stride=stride)
# Apply to mask and original image
mask_p  = unfold(mask)
patches = unfold(img)
patches.shape

This outputs a shape of torch.Size([1, 49152, 225]). To use the tiles individually, you have to reshape them:

patches.reshape(3, p_size, p_size, -1).permute(3, 0, 1, 2)

to get a shape of torch.Size([225, 3, 128, 128]). To merge the patches/tiles back into the original image:

fold     = nn.Fold(output_size=(img.shape[2], img.shape[3]), kernel_size=(p_size, p_size), stride=stride)
img_back = fold(patches)/fold(mask_p)

If you don’t divide by the mask, you will end up with areas where overlapping patches/tiles are added together. It isn’t the fanciest way, but it works reasonably well for different strides and patch sizes.
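Putting the pieces together, a complete round trip could look roughly like this (just a sketch; the per-tile processing step is left as a placeholder):

import torch
import torch.nn as nn

# Unfold into overlapping patches, work on them per tile, fold them back,
# and divide by the folded mask to undo the summation in the overlaps.
img    = torch.rand(1, 3, 1024, 1024)
mask   = torch.ones_like(img)
p_size = 128
stride = p_size // 2

unfold = nn.Unfold(kernel_size=(p_size, p_size), stride=stride)
fold   = nn.Fold(output_size=(img.shape[2], img.shape[3]),
                 kernel_size=(p_size, p_size), stride=stride)

patches = unfold(img)   # [1, 3*128*128, 225]
mask_p  = unfold(mask)

tiles = patches.reshape(3, p_size, p_size, -1).permute(3, 0, 1, 2)  # [225, 3, 128, 128]
# ... each tile would be modified here ...

# Reshape back to the layout nn.Fold expects, then normalize by the mask
patches_back = tiles.permute(1, 2, 3, 0).reshape(1, 3 * p_size * p_size, -1)
img_back = fold(patches_back) / fold(mask_p)

print(torch.allclose(img, img_back))  # True for an unmodified round trip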

@Donk10 If it helps to visualize it, I’m using the tensor splitting function with style transfer code, to lower GPU usage. Currently this is what the output looks like using my code:
(image: out_200)

I don’t see exactly how I’d apply nn.Fold/nn.Unfold to my tiling code, though?

I tested your code and it worked; I did not get these artifacts. In the example image you showed, are you applying your model to the patches/tiles before rebuilding the image?

The unfold function with stride=tile_size is equivalent to your split_tensor function, and fold is similar to your rebuild_tensor function.
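For example, a quick check of the non-overlapping case (sizes here are only illustrative):

import torch
import torch.nn as nn

# With stride == tile_size the patches are exactly the non-overlapping tiles.
x = torch.rand(1, 3, 512, 512)
tile_size = 256
unfold = nn.Unfold(kernel_size=(tile_size, tile_size), stride=tile_size)
tiles  = unfold(x).reshape(3, tile_size, tile_size, -1).permute(3, 0, 1, 2)  # [4, 3, 256, 256]

# The first tile is the top-left corner of the image.
print(torch.equal(tiles[0], x[0, :, :tile_size, :tile_size]))  # True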

In your code, patches has a size of torch.Size([1, 49152, 225]). I may just not be understanding exactly how unfold/fold works, but since it’s a single large tensor and not a list of patches, I’m not sure how I would use it with the steps I have below.

I’m doing this with my model and the tiles:

  1. Split the image into tiles.
  2. Run each tile through a new copy of the model for a set number of iterations.
  3. Put all the tiles back together, maybe save the output, then split the output into tiles again.
  4. Repeat steps 2 and 3 for a set number of iterations.

The full code can be found here: https://gist.github.com/ProGamerGov/e64fcb309274c2946f5a9a679ed45669
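
Roughly, the loop looks something like this (just a sketch, not the gist’s actual code; run_model is a placeholder for the per-tile style transfer step, and split_tensor/rebuild_tensor are the functions from my first post):

# Sketch of the tiled workflow above; assumes split_tensor()/rebuild_tensor()
# and input_tensor from my first post are already defined.
def run_model(tile, num_iters):
    return tile  # placeholder: the real code optimizes the tile here

num_passes, num_iters, tile_size = 4, 100, 512
output_tensor = input_tensor.clone()
for p in range(num_passes):
    tiles, base_t = split_tensor(output_tensor, tile_size)
    tiles = [run_model(t, num_iters) for t in tiles]      # process each tile separately
    output_tensor = rebuild_tensor(tiles, base_t, tile_size)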

Yes, you have to reshape patches to regain the spatial dimensions:

patches.reshape(3, p_size, p_size, -1).permute(3, 0, 1, 2)

The output shape will be number of patches x 3 x tile_size x tile_size. I used this code for segmentation because I also had the GPU constraint, so I applied my model patchwise.

The code you posted worked for me after splitting and rebuilding a test image, so my guess would be that either (i) you have to renormalize each tile/image (you could test that by saving the patches before rebuilding them with torchvision.utils.save_image(…, normalize=True)), or (ii) your model causes the problem. Maybe you should try some overlap/stride to get rid of the boundaries.
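
For example, something like this would save the tiles for a quick look (assuming tile_tensors is the list of tiles returned by your split_tensor):

from torchvision.utils import save_image

# Save each tile individually, rescaled to [0, 1], to inspect it before rebuilding.
for i, tile in enumerate(tile_tensors):
    save_image(tile, 'tile_{}.png'.format(i), normalize=True)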

I had to change this line to make it work correctly:

patches = patches.reshape(3, p_size, p_size, -1).permute(3, 0, 1, 2)

But now I get the following error message:

Traceback (most recent call last):
  File "fold.py", line 23, in <module>
    img_back = fold(patches)/fold(mask_p)
    ...
    raise NotImplementedError("Input Error: Only 3D input Tensors are supported (got {}D)".format(input.dim()))
NotImplementedError: Input Error: Only 3D input Tensors are supported (got 4D)

Specifically, running fold(patches) causes the error.

Would it be possible to post the code that you ran on a test image successfully?

I have this function working correctly:

def split_tensor(tensor, tile_size=256):
    mask = torch.ones_like(tensor)

    # use torch.nn.Unfold
    stride  = tile_size // 2
    unfold  = nn.Unfold(kernel_size=(tile_size, tile_size), stride=stride)
    # Apply to mask and original image
    mask_p  = unfold(mask)
    patches = unfold(tensor)

    patches = patches.reshape(3, tile_size, tile_size, -1).permute(3, 0, 1, 2)

    tiles = []
    for t in range(patches.size(0)):
        tiles.append(patches[[t], :, :, :])
    return tiles, mask_p

Yes, you have to reshape the tensor back to the original shape with

.permute(1, 2, 3, 0).reshape(1, 3*p_size*p_size, patches_shape[-1])

to get something similar to torch.Size([1, 49152, 225]), which is the shape the fold function requires. To get the same behavior as before, set the stride to stride = tile_size.

And I literally used the exact same code you posted above, just loading a random image and rebuilding it.

@Donk10 I have it working now! Thank you for the help!

import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms as transforms

def split_tensor(tensor, tile_size=256):
    mask = torch.ones_like(tensor)
    # use torch.nn.Unfold
    stride  = tile_size//2
    unfold  = nn.Unfold(kernel_size=(tile_size, tile_size), stride=stride)
    # Apply to mask and original image
    mask_p  = unfold(mask)
    patches = unfold(tensor)
	
    patches = patches.reshape(3, tile_size, tile_size, -1).permute(3, 0, 1, 2)
    if tensor.is_cuda:
        patches_base = torch.zeros(patches.size(), device=tensor.get_device())
    else:
        patches_base = torch.zeros(patches.size())

    tiles = []
    for t in range(patches.size(0)):
        tiles.append(patches[[t], :, :, :])
    return tiles, mask_p, patches_base, (tensor.size(2), tensor.size(3))

def rebuild_tensor(tensor_list, mask_t, base_tensor, t_size, tile_size=256):
    stride  = tile_size//2  

    for t, tile in enumerate(tensor_list):
        print(tile.size())
        base_tensor[[t], :, :] = tile

    base_tensor = base_tensor.permute(1, 2, 3, 0).reshape(3*tile_size*tile_size, base_tensor.size(0)).unsqueeze(0)
    fold = nn.Fold(output_size=(t_size[0], t_size[1]), kernel_size=(tile_size, tile_size), stride=stride)
    output_tensor = fold(base_tensor)/fold(mask_t)
    return output_tensor

# Load image and preprocess it into a tensor
test_image = 'test_image.jpg'
image_size=1024
Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
input_tensor = Loader(Image.open(test_image).convert('RGB')).unsqueeze(0)


# Split image into overlapping tiles
tile_tensors, mask_t, base_tensor, t_size = split_tensor(input_tensor, 660)

# Put tiles back together
output_tensor = rebuild_tensor(tile_tensors, mask_t, base_tensor, t_size, 660)

# Save Output
Image2PIL = transforms.ToPILImage()
Image2PIL(output_tensor.cpu().squeeze(0)).save('output_image.png')

Now, I need to test it with the main code.

So, for some reason I’m getting a black border on the output image now:

The tile boundaries are also really visible:

Someone who did something similar to what I’m doing appears to have used torch.linspace() and some sort of mask to blend tiles together. However, their code is written in the original Lua/Torch7 library, so I’m not 100% sure: https://github.com/VaKonS/neural-style/blob/Multi-resolution/neural_style.lua

I think that blending the tiles together like this could work a lot better:

import torch
tile_tensor = torch.randn(1, 3, 6, 6)

print(tile_tensor.size())
print(tile_tensor[0][2])

# Create mask to feather the left edge of the tile
overlap = tile_tensor.size(3) - 2
lin_tensor = torch.linspace(0, 1, overlap).repeat(tile_tensor.size(2), 1)
lin_part = torch.ones(tile_tensor.size(2), 2)
mask = torch.cat((lin_tensor, lin_part), 1)

# Apply mask to each channel of the tile
for i, t in enumerate(tile_tensor[0]):
    tile_tensor[0][i] = tile_tensor[0][i] * mask

print(tile_tensor[0][2])

However, I need to figure out the overlapping regions on each tile for it to work, and I’m not quite sure how I’d do that.

In the diagram below, the red regions are where the torch.linspace() mask will be applied. This strategy should be a lot easier to implement.
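
In code, building that sort of mask for a single tile might look something like this (a rough sketch with made-up sizes; the real overlap widths would come from the tiling code):

import torch

# Sketch of a 2D feathering mask for one tile that overlaps its left neighbour
# by w_overlap pixels and its top neighbour by h_overlap pixels.
# All sizes here are made up for illustration.
tile_h, tile_w = 128, 128
h_overlap, w_overlap = 32, 32

# Ramp from 0 to 1 across the overlapping band, 1 everywhere else.
w_ramp = torch.cat([torch.linspace(0, 1, w_overlap), torch.ones(tile_w - w_overlap)])
h_ramp = torch.cat([torch.linspace(0, 1, h_overlap), torch.ones(tile_h - h_overlap)])

# Outer product combines the two ramps into a single 2D mask.
mask = h_ramp.unsqueeze(1) * w_ramp.unsqueeze(0)   # [tile_h, tile_w]
mask = mask.expand(1, 3, tile_h, tile_w)           # broadcast over batch and channels

tile = torch.rand(1, 3, tile_h, tile_w)
feathered = tile * mask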

I’ve got this code together for the feathering, but it seems to have issues:

import torch
from PIL import Image
import torchvision.transforms as transforms

# Feather vertically          
def feather_tiles(tensor_list, rows, w_overlap):
    print(len(tensor_list))
    new_tensor_list = []
    if w_overlap > 0:
        for i, tile in enumerate(tensor_list):
            if i % rows[1] != 0:
                lin_mask = torch.linspace(0,1,w_overlap).repeat(tile.size(2),1)
                mask_part = torch.ones(tile.size(2), tile.size(3)-w_overlap)
                mask = torch.cat([lin_mask, mask_part], 1)
                mask = mask.repeat(3,1,1).unsqueeze(0)
                new_tensor_list.append(tile * mask)
            else:
                new_tensor_list.append(tile)
    else:
        new_tensor_list = tensor_list
    return new_tensor_list

# Horizontal feathering
def feather_rows(tensor_list, rows, h_overlap):
    if h_overlap > 0:
        for i, tile in enumerate(tensor_list):
            if i > rows[1]:
                for v in range(3):
                    lin_mask = torch.linspace(0,1,h_overlap).repeat(tile.size(3),1).rot90().rot90().rot90()
                    mask_part = torch.ones(tile.size(2)-h_overlap, tile.size(3))
                    mask = torch.cat((lin_mask, mask_part))	
                    tensor_list[i][0][v] = tensor_list[i][0][v] * mask					
    return tensor_list	

def split_tensor(tensor, tile_size=256):
    tiles = []
    tile_idx = []
    tile_size_y, tile_size_x = tile_size+8, tile_size +5
    h, w = tensor.size(2), tensor.size(3)
    h_range, w_range = int(-(h // -tile_size_y)), int(-(w // -tile_size_x))
	
    for y in range(h_range):       
        for x in range(w_range): 		
            y_val = max(min(tile_size_y*y+tile_size_y, h), 0)
            x_val = max(min(tile_size_x*x+tile_size_x, w), 0)
            ty = tile_size_y*y
            tx = tile_size_x*x

            if abs(tx - x_val) < tile_size_x:
                tx = x_val- tile_size_x
            if abs(ty - y_val) < tile_size_y:
                ty = y_val-tile_size_y


            tiles.append(tensor[:, :, ty:y_val, tx:x_val])
            tile_idx.append([ty, y_val, tx, x_val])
	
    w_overlap = tile_idx[0][3] - tile_idx[1][2]
    h_overlap = tile_idx[0][1] - tile_idx[w_range][0]

    if tensor.is_cuda:
        base_tensor = torch.zeros(tensor.squeeze(0).size(), device=tensor.get_device())
    else: 
        base_tensor = torch.zeros(tensor.squeeze(0).size())
    return tiles, base_tensor.unsqueeze(0), (h_range, w_range), (h_overlap, w_overlap)  

def build_row(tensor_tiles, hxw, w_overlap, bt, tile_size):
    if bt.is_cuda:
        row_base = torch.zeros(bt.size(1),tensor_tiles[0].size(2),bt.size(3), device=bt.get_device()).unsqueeze(0)
    else: 
        row_base = torch.zeros(bt.size(1),tensor_tiles[0].size(2),bt.size(3)).unsqueeze(0)
    row_list = []
    row_list_2 = []
    for v in range(hxw[1]):
      row_list.append(row_base) 
      row_list_2.append(row_base)     
        
    num_tiles = 0
    tile_size_y, tile_size_x = tile_size+8, tile_size +5
    h, w = bt.size(2), bt.size(3)
    h_range, w_range = hxw[0], hxw[1]
    for y in range(h_range):       
        for x in range(w_range):        
            y_val = max(min(tile_size_y*y+tile_size_y, h), 0)
            x_val = max(min(tile_size_x*x+tile_size_x, w), 0)
            ty = tile_size_y*y
            tx = tile_size_x*x

            if abs(tx - x_val) < tile_size_x:
                tx = x_val- tile_size_x
            if abs(ty - y_val) < tile_size_y:
                ty = y_val-tile_size_y
                
            print(row_list[y].size(), tensor_tiles[num_tiles].size())
            if x == 0:
                row_list[y][:, :, :, tx:x_val] = tensor_tiles[num_tiles]
            else:
                row_list_2[y][:, :, :, tx:x_val] = tensor_tiles[num_tiles]
                row_list[y] = row_list[y] + row_list_2[y] / row_list_2[y]               
            num_tiles+=1        
    return row_list

# Test Functions
test_image = 'test_image.jpg'
image_size=(256,202)
Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
input_tensor = Loader(Image.open(test_image).convert('RGB')).unsqueeze(0)

tile_tensors, base_t, hxw, ovlp = split_tensor(input_tensor, 128)
feathered_tiles = feather_tiles(tile_tensors, hxw, ovlp[1])
row_tensors = build_row(feathered_tiles, hxw, ovlp[1], base_t, 128)
feathered_row_tensors = feather_rows(row_tensors , hxw, ovlp[0])

Any help with getting these two functions working properly would be really appreciated!

This is what the first row looks like after the tiles are run through the build_row() function:

(image: ft_row_0)
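
For reference, the divide-by-mask idea from earlier in the thread can also be written without nn.Fold, by pasting mask-weighted tiles onto a zero tensor and dividing by the accumulated masks at the end. A rough sketch with made-up tile positions (not my actual tiling code):

import torch

# Sketch: accumulate mask-weighted tiles and the masks themselves, then divide.
# Sizes and tile positions are illustrative only.
h, w, tile = 256, 256, 160
img_acc  = torch.zeros(1, 3, h, w)
mask_acc = torch.zeros(1, 3, h, w)

corners = [(0, 0), (0, w - tile), (h - tile, 0), (h - tile, w - tile)]  # top-left corners
for ty, tx in corners:
    t = torch.rand(1, 3, tile, tile)   # stands in for a processed tile
    m = torch.ones(1, 3, tile, tile)   # a real feathering mask would ramp in the overlap bands
    img_acc[:, :, ty:ty + tile, tx:tx + tile]  += t * m
    mask_acc[:, :, ty:ty + tile, tx:tx + tile] += m

blended = img_acc / mask_acc.clamp(min=1e-8)   # averages wherever tiles overlap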