I have an implementation using tensors. It is slow. What takes around 5s with Conv2d takes around 116s with my custom Conv2dPS.
Are there any obvious performance improvements I can make to this?
class Conv2dPS(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, padding=0):
super(Conv2dPS, self).__init__()
# Initialize parameters for the convolution layer
# in_channels: number of input channels
# out_channels: number of output channels
# kernel_size: size of the convolution kernel
# Define A (weights) for the convolution layer
# Use nn.Parameter to create a learnable parameter for the weights
self.A = nn.Parameter(torch.Tensor(out_channels, in_channels, kernel_size, kernel_size))
# Define B (offset) as a learnable parameter
# B should have the same dimensions as the kernel
self.B = nn.Parameter(torch.Tensor(out_channels, in_channels, kernel_size, kernel_size))
# Initialize A and B with appropriate initialization methods, e.g., Xavier initialization
nn.init.xavier_uniform_(self.A)
nn.init.xavier_uniform_(self.B)
# Set padding for the convolution operation
self.padding = padding
def convolve(self, x):
batch_size = x.shape[0]
image_channels = x.shape[1]
image_height = x.shape[2]
image_width = x.shape[3]
out_channels = self.A.shape[0]
in_channels = self.A.shape[1]
kernel_height = self.A.shape[2]
kernel_width = self.A.shape[3]
assert(image_channels == in_channels)
assert(kernel_height == kernel_width)
# F.unfold takes an input tensor and extracts sliding local blocks (or patches) from it.
# These blocks are the regions of the input tensor over which the convolution operation
# (filter application) will take place.
# x_unfolded, will have a shape of [batch_size, in_channels * kernel_height * kernel_width, num_patches]
# The output will look something like this:
# tensor([[[ 1., 2., 3., 5., 6., 7., 9., 10., 11.],
# [ 2., 3., 4., 6., 7., 8., 10., 11., 12.],
# [ 5., 6., 7., 9., 10., 11., 13., 14., 15.],
# [ 6., 7., 8., 10., 11., 12., 14., 15., 16.]]])
# The first patch is the first element of each row: [1, 2, 5, 6]
x_unfolded = F.unfold(x, kernel_size=kernel_height, padding=self.padding)
unfolded_batch_size = x_unfolded.shape[0]
unfolded_patch_size = x_unfolded.shape[1]
num_patches = x_unfolded.shape[2]
assert(unfolded_batch_size == batch_size)
assert(unfolded_patch_size == in_channels * kernel_height * kernel_width)
# Reshape x_unfolded into a format that aligns with the convolution weights A
# transpose dimensions 1 and 2 above into [batch, num_patches, in_channels * kernel_height * kernel_width]
# then view as [batch, num_patches, in_channels, kernel_height, kernel_width]
x_unfolded = x_unfolded.permute(0, 2, 1).view(batch_size, num_patches, in_channels, kernel_height, kernel_width)
# Expand x_unfolded across output_channels to match the dimensions of B_expanded
x_expanded = x_unfolded.unsqueeze(2).expand(batch_size, num_patches, out_channels, in_channels, kernel_height, kernel_width)
return x_expanded
def subtract_offset(self, x_convolve):
batch_size = x_convolve.shape[0]
num_patches = x_convolve.shape[1]
x_out_channels = x_convolve.shape[2]
x_in_channels = x_convolve.shape[3]
x_kernel_height = x_convolve.shape[4]
x_kernel_width = x_convolve.shape[5]
out_channels = self.B.shape[0]
in_channels = self.B.shape[1]
kernel_height = self.B.shape[2]
kernel_width = self.B.shape[3]
assert(x_out_channels == out_channels)
assert(x_in_channels == in_channels)
assert(x_kernel_height == kernel_height)
assert(x_kernel_width == kernel_width)
# Reshape B to match the dimensions of x_unfolded, but keeping its unique values per filter
# Current shape of B: [out_channels, in_channels, kernel_height, kernel_width]
# We need B to have the shape: [batch_size, in_channels, patch_size, num_patches]
# Expand B across the batch_size and num_patches dimensions
B_reshaped = self.B.view(1, 1, out_channels, in_channels, kernel_height, kernel_width)
B_expanded = B_reshaped.expand(batch_size, num_patches, out_channels, in_channels, kernel_height, kernel_width)
# Subtract B_expanded from each patch
x_offset = x_convolve - B_expanded
return x_offset
def multiply_weights(self, x_offset):
batch_size = x_offset.shape[0]
num_patches = x_offset.shape[1]
x_out_channels = x_offset.shape[2]
x_in_channels = x_offset.shape[3]
x_kernel_height = x_offset.shape[4]
x_kernel_width = x_offset.shape[5]
out_channels = self.A.shape[0]
in_channels = self.A.shape[1]
kernel_height = self.A.shape[2]
kernel_width = self.A.shape[3]
assert(x_out_channels == out_channels)
assert(x_in_channels == in_channels)
assert(x_kernel_height == kernel_height)
assert(x_kernel_width == kernel_width)
# Multiply A with x_offset and sum over the kernel dimensions
# A shape: [out_channels, in_channels, kernel_height, kernel_width]
# x_offset shape: [batch_size, num_patches, in_channels, kernel_height, kernel_width]
return self.A.unsqueeze(0).unsqueeze(1) * x_offset
def forward(self, x):
batch_size = x.shape[0]
image_channels = x.shape[1]
image_height = x.shape[2]
image_width = x.shape[3]
out_channels = self.A.shape[0]
in_channels = self.A.shape[1]
kernel_height = self.A.shape[2]
kernel_width = self.A.shape[3]
assert(image_channels == in_channels)
assert(kernel_height == kernel_width)
# 1. Extract image patches
x_convolve = self.convolve(x)
# 2. Calculate P_offset = P - B
x_offset = self.subtract_offset(x_convolve)
# 3. Perform the convolution operation y = A * P_offset
output = torch.sum(self.multiply_weights(x_offset), dim=[3, 4, 5])
# 4. Reshape the output
# Calculate the dimensions of the output feature map
output_height = (image_height + 2 * self.padding - (kernel_height - 1) - 1) + 1
output_width = (image_width + 2 * self.padding - (kernel_width - 1) - 1) + 1
# Reshape output to the shape (batch_size, out_channels, output_height, output_width)
output = output.view(batch_size, out_channels, output_height, output_width)
return output