Fast differentiable blurring

How can I efficiently blur a binary image in PyTorch, using a floating-point parameter to control the amount of blur (ideally so that the result is differentiable with respect to that parameter)?

This might work:

import torch.nn.functional as F
import torch

def make_gaussian_kernel(kernel_size, sigma):
    """Return a normalized 1-D Gaussian kernel with ``kernel_size`` taps.

    Taps are sampled at unit spacing, centred on zero:
    ``-(kernel_size - 1) / 2, ..., (kernel_size - 1) / 2``. For an odd
    ``kernel_size`` the middle tap sits exactly at offset 0.

    Args:
        kernel_size: Number of taps; must be >= 1.
        sigma: Standard deviation in pixels; must be > 0. May be a Python
            number or a 0-d tensor. When it is a tensor, the sample grid is
            built on the same device, and if it requires grad the returned
            kernel is differentiable with respect to it.

    Returns:
        A 1-D floating-point tensor of shape ``(kernel_size,)`` whose
        entries sum to 1.

    Raises:
        ValueError: If ``kernel_size`` < 1 or ``sigma`` is not > 0.
    """
    if kernel_size < 1:
        raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
    # `not (x > 0)` also rejects NaN; float() accepts 0-d tensors as well.
    # sigma == 0 would otherwise give 0/0 = NaN at the centre tap.
    if not float(sigma) > 0:
        raise ValueError(f"sigma must be > 0, got {float(sigma)}")

    # Build the grid where sigma lives so a CUDA sigma tensor can be used.
    device = sigma.device if isinstance(sigma, torch.Tensor) else None
    # Unit-spaced offsets centred on zero. Note that `-k // 2` parses as
    # `(-k) // 2`, so a linspace(-k // 2, k // 2 + 1, k) grid is stretched
    # by (k + 1) / (k - 1) and would silently shrink the effective sigma.
    ts = torch.arange(kernel_size, device=device) - (kernel_size - 1) / 2
    gauss = torch.exp(-0.5 * (ts / sigma) ** 2)
    return gauss / gauss.sum()

def fast_gaussian_blur(
    img: torch.Tensor, sigma: float, *, padding_mode: str = "constant"
) -> torch.Tensor:
    """Blur ``img`` with a separable Gaussian of standard deviation ``sigma``.

    The 2-D blur is applied as two depthwise 1-D convolutions (vertical,
    then horizontal), so each channel is filtered independently. The kernel
    has ``int(sigma * 5)`` taps, bumped to the next odd number, so it is
    centred on a pixel and the output keeps the input's spatial size.

    Args:
        img: Tensor of shape ``(..., C, H, W)`` with at least 3 dims; any
            leading dims are treated as batch dims. Non-floating inputs
            (e.g. a bool or uint8 binary mask) are first converted to the
            default floating dtype.
        sigma: Standard deviation of the Gaussian in pixels; must be > 0.
            A 0-d tensor with ``requires_grad=True`` may be passed to get
            gradients w.r.t. the blur amount. The kernel *size* is still
            picked from its current value and is not differentiable.
        padding_mode: Border handling forwarded to
            ``torch.nn.functional.pad``. The default ``"constant"`` pads
            with zeros; ``"replicate"`` or ``"reflect"`` avoid darkening
            the image borders (``"reflect"`` needs H and W > kernel radius).

    Returns:
        Blurred tensor with the same shape as ``img``.

    Raises:
        ValueError: If ``img`` has fewer than 3 dims or ``sigma`` is not > 0.
    """
    if img.dim() < 3:
        raise ValueError(
            f"expected img of shape (..., C, H, W), got {tuple(img.shape)}"
        )
    # `not (x > 0)` also rejects NaN; float() accepts 0-d tensors as well.
    if not float(sigma) > 0:
        raise ValueError(f"sigma must be > 0, got {float(sigma)}")

    kernel_size = int(sigma * 5)
    if kernel_size % 2 == 0:
        kernel_size += 1  # odd size => kernel centred on a pixel

    # conv2d needs a floating input, and input/weight must share device and dtype.
    if not img.is_floating_point():
        img = img.to(torch.get_default_dtype())
    kernel = make_gaussian_kernel(kernel_size, sigma).to(
        device=img.device, dtype=img.dtype
    )

    # Fold every leading dim into one batch dim: conv2d expects (N, C, H, W).
    *batch_shape, channels, height, width = img.shape
    x = img.reshape(-1, channels, height, width)

    # padding = (left, right, top, bottom); each valid conv below removes
    # exactly 2 * pad rows/cols, restoring the original H and W.
    pad = kernel_size // 2
    x = F.pad(x, [pad, pad, pad, pad], mode=padding_mode)

    # Separable 2-D conv as two depthwise passes. conv2d weights have shape
    # (out_channels, in_channels / groups, kH, kW); with groups=channels each
    # channel gets its own copy of the 1-D kernel and channels never mix.
    weight_v = kernel.view(1, 1, kernel_size, 1).repeat(channels, 1, 1, 1)
    weight_h = kernel.view(1, 1, 1, kernel_size).repeat(channels, 1, 1, 1)
    x = F.conv2d(x, weight_v, groups=channels)
    x = F.conv2d(x, weight_h, groups=channels)

    return x.reshape(*batch_shape, channels, height, width)

I have not fully tested it yet. Any and all feedback is appreciated 🙂