Overlap two tensors of different sizes based on an offset in PyTorch (or how to do PIL.paste but with tensors)

I have the following structure:

torch.Size([channels, width, height])

Let’s say I have a tensor a

torch.Size([4, 512, 512])

And tensor b

torch.Size([4, 100, 100])

What I would like to do is create a tensor c that is the result of “placing” tensor b at an arbitrary (width, height) coordinate offset of tensor a. For example, let’s say I would like to place tensor b at (300, 100) of tensor a.

So for positions 300 to 400 along tensor a's width, the values of tensor a should be replaced by the 100 values along tensor b's width.

For positions 100 to 200 along tensor a's height, the values of tensor a should be replaced by the 100 values along tensor b's height.

I would also like to choose which channels get this substitution and which channels keep tensor a's values.

(PS: The image is just an easy-to-illustrate example. I would like to do this in a more general way, so I’m not interested in converting to PIL, using PIL.paste and converting back to a tensor; I would like to do all operations directly with tensors.)

You can index that region directly and assign new values to it:

# The shape of new_values has to match the patch selected from 'a'
a[channels, y1:y2, x1:x2] = new_values
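As a sketch with the sizes from the question, assuming the usual PyTorch layout [channels, height, width] (which is what the indexing above uses: rows y first, then columns x), placing b at (x, y) = (300, 100) while replacing only a hypothetical selection of channels could look like this. Cloning a first gives the separate result tensor c the question asks for and leaves a untouched.

```python
import torch

# Shapes from the question, assuming the layout [channels, height, width]
a = torch.rand(4, 512, 512)
b = torch.rand(4, 100, 100)

x1, y1 = 300, 100   # x is the offset along the width, y along the height
channels = [0, 2]   # hypothetical choice: only replace channels 0 and 2

c = a.clone()       # work on a copy so 'a' keeps its original values
h, w = b.shape[-2], b.shape[-1]

# Rows y1..y1+h and columns x1..x1+w of the selected channels come from 'b'
c[channels, y1:y1 + h, x1:x1 + w] = b[channels]
```

Channels that are not listed in channels, and every pixel outside the 100x100 window, keep the values from a.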

For something more general, you can do something like this:

def paste(a, b, x1, y1, channels=None):
    """Paste patch 'b' into 'a' (in place) with its top-left corner at (x1, y1)."""
    assert a.dim() >= 3, f"Expected at least 3 dimensions for tensor 'a', got {a.dim()}"
    assert b.dim() in (2, 3), f"Expected 2 or 3 dimensions for tensor 'b', got {b.dim()}"

    # By default, paste into every channel of 'a'
    channels = channels if channels is not None else list(range(a.shape[-3]))
    if b.dim() == 3:
        assert b.shape[-3] in (a.shape[-3], len(channels), 1), (
            "'b' must have as many channels as 'a', as many as the selected channels, or exactly 1"
        )

    # Patch size, cropped so it does not run past the bottom/right edge of 'a'
    _h, _w = b.shape[-2], b.shape[-1]
    h = min(_h, a.shape[-2] - y1)
    w = min(_w, a.shape[-1] - x1)

    if b.dim() == 3 and a.shape[-3] == b.shape[-3]:
        # 'b' has one channel per channel of 'a': take only the selected ones
        a[..., channels, y1:y1+h, x1:x1+w] = b[channels, :h, :w]
    else:
        # 'b' matches the selection, or is single-channel/2-D and broadcasts
        a[..., channels, y1:y1+h, x1:x1+w] = b[..., :h, :w]

    return a

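Here you can pass an image (or a batch of images) as a, together with the patch b to paste into it.
If the patch runs past the right or bottom edge of a, it is cropped to fit. Note that paste modifies a in place (and also returns it), so pass a.clone() if you want a new tensor c and need to keep a unchanged.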
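To make the batch and cropping behaviour concrete, here is a minimal sketch of just those two steps outside of paste, with made-up sizes: a batch of two 64x64 RGB images and a 30x30 patch placed at (x1, y1) = (50, 40), where only 14 columns and 24 rows fit. The leading ... in the index applies the same slice to every image in the batch, and min() trims the patch to the space left in a.

```python
import torch

a = torch.zeros(2, 3, 64, 64)   # batch of 2 RGB images, [N, C, H, W]
b = torch.ones(3, 30, 30)       # patch larger than the space left at the offset
x1, y1 = 50, 40

# Crop the patch so it stays inside 'a'
h = min(b.shape[-2], a.shape[-2] - y1)   # 64 - 40 = 24 rows fit
w = min(b.shape[-1], a.shape[-1] - x1)   # 64 - 50 = 14 columns fit

# '...' covers the batch dimension; 'b' broadcasts over it
a[..., y1:y1 + h, x1:x1 + w] = b[..., :h, :w]
```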

You can also select which channels to apply the patch to; by default it is applied to every channel.

b can have the same number of channels as a, the same number as the channel selection, or a single channel (or no channel dimension at all), in which case it is broadcast to every selected channel.
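The shape rules for b come down to ordinary PyTorch advanced indexing plus broadcasting. The sketch below uses toy sizes and a hypothetical selection channels=[0, 2], and writes a 4x4 patch into a zero tensor for each accepted shape of b; patched is a helper made up for this example, not part of the answer's code.

```python
import torch

a = torch.zeros(3, 8, 8)   # small 3-channel "image", [C, H, W]
channels = [0, 2]          # hypothetical selection: first and last channel
y1, x1 = 2, 3
h, w = 4, 4

def patched(patch):
    """Return a copy of 'a' with 'patch' written into the selected channels."""
    out = a.clone()
    out[channels, y1:y1 + h, x1:x1 + w] = patch
    return out

# Same number of channels as 'a': pick the selected ones first (as paste does)
b_full = torch.arange(1.0, 4.0).view(3, 1, 1).expand(3, h, w)  # channel k holds k + 1
a_full = patched(b_full[channels])

# Same number of channels as the selection: assigned one-to-one
b_sel = torch.stack([torch.full((h, w), 10.0), torch.full((h, w), 20.0)])  # [2, h, w]
a_sel = patched(b_sel)

# Single channel [1, h, w] or plain [h, w]: broadcast to every selected channel
a_one = patched(torch.full((1, h, w), 5.0))
a_2d = patched(torch.full((h, w), 7.0))
```

In every case channel 1 is left untouched, because it is not in the selection.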

Testing

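from io import BytesIO

import requests
import torch
import torchvision
from IPython.display import display  # display() is available in Jupyter/IPython
from PIL import Image

url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcQiFzoiIq7R99280lBVEcLAHRpb3LTz0ybB0A&usqp=CAU"
response = requests.get(url)
img = Image.open(BytesIO(response.content)).convert("RGB")  # make sure the image has 3 channels

display(img)

img_t = torchvision.transforms.ToTensor()(img)  # [3, H, W], values in [0, 1]

x1, y1 = 150, 65
b = torch.ones(3, 30, 80)  # white patch, 30 pixels high and 80 wide
mask = [1]                 # indices of the channels to patch (here: green only)

img_p = paste(img_t, b, x1, y1, channels=mask)
display(torchvision.transforms.ToPILImage()(img_p))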
[Figure: image before patching]

[Figure: image after patching]

If this doesn't solve your use case, please let me know.