Customized Convolution with Dynamic Kernels

Hi all,

I am trying to implement a customized convolution. I have two inputs, image and x, with shapes 1 x N x N and batch_size x 1 x N x N, respectively. The kernels depend on image and are applied to x. That is, we learn a 9 x 9 linear map from each 3 x 3 patch of image to a 3 x 3 kernel, and then apply that kernel to the corresponding 3 x 3 patch of x.
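
Concretely, for each output location (i, j) I want the following (pseudocode; W and b are just my names for the learned 9 x 9 map and its bias, and patch3x3(t, i, j) stands for the 3 x 3 patch of t around (i, j); none of these names appear in the code below):

K_ij = (W @ patch3x3(image, i, j).flatten() + b).view(3, 3)
y[:, 0, i, j] = (K_ij * patch3x3(x, i, j)).sum(dim=(-2, -1))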

The following is my current implementation. I have two versions: forward takes up too much memory and is still relatively slow compared to the built-in Conv2d; forward1 saves memory, but because of its naive Python for-loops it is VERY slow. How can I make this fast with limited CUDA memory?

import torch
import torch.nn as nn
import torch.nn.functional as F

class SmallSMBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.KL = nn.Conv2d(1, 9, kernel_size=3, padding='same', bias=True)
    def forward1(self, image, x):  # image: 1 x N x N, x: bs x 1 x N x N
        N = x.shape[-1]
        y = torch.zeros_like(x)
        image = F.pad(image, (1,) * 4)  # 1 x (N+2) x (N+2)
        x = F.pad(x, (1,) * 4)  # bs x 1 x (N+2) x (N+2)
        for i in range(1, N + 1):
            for j in range(1, N + 1):
                # weight is 9 x 1 x 3 x 3; squeezing and summing against the
                # 3 x 3 image patch gives one kernel entry per output channel
                K = (self.KL.weight.squeeze() * image[0, i-1:i+2, j-1:j+2]).sum(dim=(1, 2)) + self.KL.bias
                K = K.view(3, 3)  # 9 -> 3 x 3 kernel at this location
                xx = x[:, 0, i-1:i+2, j-1:j+2]  # bs x 3 x 3 patch of x
                y[:, 0, i-1, j-1] = (xx * K).sum(dim=(-2, -1))
        return y
    def forward(self, image, x):  # image: 1 x N x N, x: bs x 1 x N x N
        K = self.KL(image) # 1 x N x N -> 9 x N x N
        K = K.permute((1, 2, 0)) # 9 x N x N -> N x N x 9
        K = K.unflatten(2, (3, 3)) # N x N x 9 -> N x N x 3 x 3
        x = F.pad(x, (1, 1, 1, 1)) # bs x 1 x N x N -> bs x 1 x (N+2) x (N+2)
        x = x.unfold(2, 3, 1).unfold(3, 3, 1) # bs x 1 x (N+2) x (N+2) -> bs x 1 x N x N x 3 x 3
        y = (x * K).sum(dim=(-2, -1))  # bs x 1 x N x N; the elementwise product materializes a bs x 1 x N x N x 3 x 3 tensor, which is where the memory goes
        return y
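
For reference, this is how I sanity-check that the two versions agree (a minimal sketch; N = 16, bs = 4, and the tolerance are arbitrary values I picked for testing):

block = SmallSMBlock()
image = torch.randn(1, 16, 16)  # 1 x N x N
x = torch.randn(4, 1, 16, 16)  # bs x 1 x N x N
with torch.no_grad():
    y_fast = block(image, x)  # vectorized forward
    y_slow = block.forward1(image, x)  # looped reference
print(torch.allclose(y_fast, y_slow, atol=1e-6))  # expect True (up to float rounding)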