Is interpolated-dilation convolution a thing?

In a vanilla dilated convolution kernel there is empty space between the sample points (weights).
The purpose of dilation is to increase the receptive field of the convolution while keeping the number of weights low.
I was wondering whether it is possible to linearly interpolate between the weights of a dilated convolution kernel so as to fill that empty space, and to increase the receptive field even further by letting the interpolation fade out to zero past the edges of the canonical sampling square/window. For example, a 3x3 vanilla kernel could be "interpodilated" to 5x5 and then even to 7x7, if such an off-edge fade-out is included, all while retaining the same puny 9 weights.

This approach should be doable by creating a custom nn.Module that holds the 3x3 weight parameter (and bias) in its __init__. Inside forward you could then build the interpolated kernel (the 5x5 or 7x7 one) from those weights and apply it to the input via the functional API, F.conv2d.
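To make the interpolation part concrete, here is a rough 1-D toy version (not the actual implementation, just the idea): a transposed convolution with a [0.5, 1.0, 0.5] kernel and stride 2 blows a 3-tap kernel up to 7 taps.

import torch
import torch.nn.functional as F

w = torch.tensor([[[1.0, 2.0, 3.0]]])         # shape (1, 1, 3): the three learned weights
bilinear = torch.tensor([[[0.5, 1.0, 0.5]]])  # 1-D linear interpolation kernel
print(F.conv_transpose1d(w, bilinear, stride=2))
# tensor([[[0.5000, 1.0000, 1.5000, 2.0000, 2.5000, 3.0000, 1.5000]]])

The in-between values are the averages of their neighbours, and the edges fade out to half, which is exactly the off-edge fade-out described above.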

OK, I think I managed. Code of the class is below; it's a drop-in replacement for the Net class in the CIFAR-10 classifier tutorial.
If I bypass the bloom_2x call (the commented-out lines in forward use the plain 3x3 weights instead), training and performance are exactly the same as the vanilla tutorial, so I think it's at least superficially correct. If I let bloom_2x do its work, the loss appears to drop more quickly at first but it seems to flatten onto a worse plateau than the vanilla 3x3. Maybe the initialization would have to be tweaked as well, instead of just reusing the default from self.conv1, but I don't have the resources to optimize that.
Oh, and there's also the hard-coded interpolation kernel [[0.25, 0.5, 0.25], [0.5, 1.0, 0.5], [0.25, 0.5, 0.25]] inside bloom_2x. That's another filthy can of hyperparameters.
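For what it's worth, that 3x3 is just the outer product of the 1-D linear interpolation taps [0.5, 1.0, 0.5] with themselves, so one obvious knob would be swapping in a different 1-D profile:

profile = torch.tensor([0.5, 1.0, 0.5])   # 1-D interpolation taps; any other profile could go here
weights = torch.outer(profile, profile).view(1, 1, 3, 3)
# -> [[0.25, 0.5, 0.25], [0.5, 1.0, 0.5], [0.25, 0.5, 0.25]], the kernel used in bloom_2x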
I suspect it might do better on large image inputs paired with large stride values.
One thing I didn't consider in my first post, of course, is that computing a 7x7 convolution is slower than a 3x3 one, even though the number of weights stays the same. Just wanted to mention that.
I'm rambling now, I need to go to sleep.

import torch
import torch.nn as nn
import torch.nn.functional as F


def bloom_2x(original_kernel):
	# original_kernel must be of shape (channels_out, channels_in, kernel_size, kernel_size)
	origsize = original_kernel.size()
	original_kernel_size = origsize[-1]
	# Treat every (out, in) channel pair as a separate single-channel "image" so the same
	# interpolation kernel can be applied to all of them in one call.
	original_channels_amount = origsize[0] * origsize[1]
	original_kernel = original_kernel.reshape(original_channels_amount, 1, original_kernel_size, original_kernel_size)
	# 2-D bilinear interpolation kernel, created on the same device/dtype as the weights.
	weights = torch.tensor([[[[0.25, 0.5, 0.25], [0.5, 1.0, 0.5], [0.25, 0.5, 0.25]]]], dtype=original_kernel.dtype, device=original_kernel.device)
	# Transposed convolution with stride 2 turns a k x k kernel into a (2k+1) x (2k+1) one.
	new_kernel_size = original_kernel_size * 2 + 1
	res = F.conv_transpose2d(original_kernel, weights, stride=2)
	return res.view(origsize[0], origsize[1], new_kernel_size, new_kernel_size)
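
Quick shape check, with the same channel counts as conv1 below:

k = torch.randn(35, 3, 3, 3)
print(bloom_2x(k).shape)  # torch.Size([35, 3, 7, 7])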

class Net(nn.Module):
	def __init__(self):
		super().__init__()

		self.conv1 = nn.Conv2d(3, 35, 3, stride=1, padding=1)
		self.pool = nn.MaxPool2d(2, 2)
		self.conv2 = nn.Conv2d(35, 35, 3, stride=1, padding=1)
		self.fc1 = nn.Linear(35 * 8 * 8, 10)

		# The actual learnable 3x3 weights/bias for the bloomed layer. They reuse conv1's
		# default initialization; conv1 itself is only used by the commented-out vanilla path.
		self.register_parameter(name='bloom_weights', param=torch.nn.Parameter(self.conv1.weight.data))
		self.register_parameter(name='bloom_biases', param=torch.nn.Parameter(self.conv1.bias.data))

	def forward(self, x):
		# Build the 7x7 kernel from the 3x3 parameters on every forward pass.
		bloom_kernel = bloom_2x(self.bloom_weights)
		#bloom_kernel = self.bloom_weights
		x = self.pool(F.relu(F.conv2d(x, bloom_kernel, self.bloom_biases, stride=1, padding=(bloom_kernel.size(-1) - 1) // 2)))
		#x = self.pool(F.relu(self.conv1(x)))
		x = self.pool(F.relu(self.conv2(x)))
		x = torch.flatten(x, 1)
		x = self.fc1(x)
		return x
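
A quick smoke test on a CIFAR-10-sized batch, in case anyone wants to check the plumbing:

net = Net()
x = torch.randn(4, 3, 32, 32)              # fake CIFAR-10 batch
print(net(x).shape)                        # torch.Size([4, 10])
print(bloom_2x(net.bloom_weights).shape)   # torch.Size([35, 3, 7, 7])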