Expand dense kernel into a dilated one with zero holes

I’m trying to expand a dense kernel tensor into a dilated one, filling the new holes with zero weights. Suppose I have an [N,K,K] tensor, where N is the number of features and K is the kernel size. A dilated kernel with dilation D has effective kernel size L = K + (K-1)*(D-1). I want to map [N,K,K] -> [N,L,L]. This essentially blows up each kernel along its last two (spatial) dimensions to create the dilated pattern, with the newly inserted entries set to zero.
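For reference, the effective size comes from counting positions along one axis: the K original taps are separated by K-1 gaps, and each gap receives D-1 inserted zeros. The original taps therefore land at indices 0, D, 2D, ..., (K-1)D:

$$L = K + (K-1)(D-1) = (K-1)D + 1, \qquad \text{e.g. } K=7,\ D=3:\ L = 7 + 6 \cdot 2 = 19.$$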

Is there an easy way to perform such an operation in PyTorch? Right now I’m manually figuring out all non-zero dilation indices by looping over the effective kernel window, which is quite inefficient. I’m thinking there’s a way to do this with a smart combination of expand/reshape, but I’m not seeing it.

Hi,

I think you can construct a transformation matrix and implement this with a matrix multiplication. Here is an example of how to insert zeros into rows and columns. Suppose K=2 and the weights are ‘A, B, C, D’. The transformations look like this (hopefully you get my point):


Then you can combine them for your task.
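A minimal sketch of this idea for K=2. The dilation D=2 is an assumption chosen for illustration, and the symbolic weights A, B, C, D are replaced by the numbers 1 to 4. Left-multiplying by a 3×2 "insertion" matrix adds a zero row. Right-multiplying by its transpose adds a zero column.

```python
import torch

# K=2 kernel; 1, 2, 3, 4 stand in for the weights A, B, C, D
x = torch.tensor([[1., 2.],
                  [3., 4.]])

# 3x2 matrix that inserts one zero row between the two rows (dilation D=2 assumed)
insert_rows = torch.tensor([[1., 0.],
                            [0., 0.],
                            [0., 1.]])

# Left multiplication inserts the zero row: (3x2) @ (2x2) -> 3x2
rows = insert_rows @ x
# Right multiplication by the transpose inserts the zero column: (3x2) @ (2x3) -> 3x3
d = rows @ insert_rows.t()
print(d)
# tensor([[1., 0., 2.],
#         [0., 0., 0.],
#         [3., 0., 4.]])
```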

Correct me if I’m wrong :slight_smile:

Awesome. I worked it out. The transformation matrix h is what you get by keeping every D-th row of an L×L identity matrix, which gives a K×L matrix. The dilated kernel is then h^T x h. Here’s an implementation and visualization:

```python
import torch
import matplotlib.pyplot as plt

kernel_size = 7
dilation = 3

x = torch.rand(kernel_size, kernel_size)
# Effective size of the dilated kernel: L = K + (K - 1) * (D - 1)
eff_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)

# Keep every `dilation`-th row of an L x L identity matrix -> shape [K, L]
h = torch.eye(eff_kernel_size)[::dilation]
# h^T @ x @ h moves x[i, j] to (i * dilation, j * dilation) and leaves zeros elsewhere
d = h.t() @ x @ h

fig = plt.figure()
ax1 = fig.add_subplot(1, 2, 1)
ax1.matshow(x.numpy(), interpolation='nearest')
ax1.title.set_text('Dense')
ax2 = fig.add_subplot(1, 2, 2)
ax2.matshow(d.numpy(), interpolation='nearest')
ax2.title.set_text('Dilated')
plt.show()
```
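The snippet above handles a single K×K kernel, while the question asks about an [N,K,K] tensor. A possible extension (my own sketch, not from the original post) relies on `torch.matmul` broadcasting the 2-D matrix h over the leading batch dimension, so h^T x h is applied to every kernel at once. The helper name `dilate_kernels` and the example size N=4 are hypothetical.

```python
import torch

def dilate_kernels(x: torch.Tensor, dilation: int) -> torch.Tensor:
    """Hypothetical helper: map [N, K, K] -> [N, L, L] with zeros in the holes."""
    k = x.shape[-1]
    eff = k + (k - 1) * (dilation - 1)
    # Every `dilation`-th row of an L x L identity matrix: shape [K, L]
    h = torch.eye(eff, dtype=x.dtype, device=x.device)[::dilation]
    # [L, K] @ [N, K, K] -> [N, L, K], then @ [K, L] -> [N, L, L]
    return h.t() @ x @ h

x = torch.rand(4, 7, 7)  # N=4 kernels of size 7x7 (arbitrary example)
d = dilate_kernels(x, dilation=3)
print(d.shape)  # torch.Size([4, 19, 19])
```

Because h only contains zeros and ones, the original weights come through unchanged at every third row and column.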

EDIT: Since PyTorch 1.9, there is torch.kron(), which computes the Kronecker product, so you can get a pretty good speed-up by simply running:

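```python
# Reuses x, dilation and eff_kernel_size from the snippet above.
# D x D block with a single 1 in the top-left corner
h = torch.zeros(dilation, dilation)
h[0, 0] = 1
# kron replaces each x[i, j] with x[i, j] * h, giving a (K*D) x (K*D) matrix;
# cropping drops the trailing D - 1 all-zero rows and columns, leaving L x L
d = torch.kron(x, h)[:eff_kernel_size, :eff_kernel_size]
```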
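Another option, not from the original thread, is to skip the multiplication entirely. Allocate a zero tensor of the effective size and write the dense kernels into every D-th row and column with strided slicing. This handles the [N,K,K] case from the question directly. The sketch below cross-checks it against the torch.kron version on one kernel. The sizes are arbitrary.

```python
import torch

kernel_size, dilation = 7, 3
eff_kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)

# Batched case from the question: N=4 kernels (arbitrary example size)
x = torch.rand(4, kernel_size, kernel_size)

# Zero-filled [N, L, L] output, then write x into every `dilation`-th row/column
d = x.new_zeros(x.shape[0], eff_kernel_size, eff_kernel_size)
d[:, ::dilation, ::dilation] = x

# Cross-check the first kernel against the torch.kron approach
h = torch.zeros(dilation, dilation)
h[0, 0] = 1
d_kron = torch.kron(x[0], h)[:eff_kernel_size, :eff_kernel_size]
print(torch.equal(d[0], d_kron))  # True
```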
