Hello,
Recently, I read “FLatten Transformer: Vision Transformer using Focused Linear Attention” (arXiv:2308.00442) and became curious about the method described in the paper for turning a depthwise convolution (DWC) into a simple matmul operation. According to the paper, a DWC with a (192, 14, 14) input and a (5, 5) kernel can be replaced by a matrix multiplication of the form (3, 196, 196) @ (192, 14, 14). However, I am having difficulty implementing this.
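To make the setting concrete, the reference operation I am trying to replace is just a depthwise conv with `groups` equal to the channel count (shapes chosen to match the paper's example; the batch size of 1 is my own assumption):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 192, 14, 14)            # (B, C, H, W) as in the paper's example
w = torch.randn(192, 1, 5, 5)              # one 5x5 filter per channel
y = F.conv2d(x, w, padding=2, groups=192)  # groups=C makes this depthwise
print(y.shape)                             # torch.Size([1, 192, 14, 14])
```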
```python
import torch
import torch.nn.functional as F

def depthwise_conv2d_matmul(input, weight, bias=None, stride=1, padding=2, dilation=1):
    bsz, channels, h, w = input.shape
    k_channels, _, k_h, k_w = weight.shape
    assert h == w, 'Input tensor must be square'
    assert channels == k_channels, 'Number of input channels and kernel channels must match'
    assert k_h == k_w, 'Kernel must be square'
    input_size = h
    kernel_size = k_h

    # Split the input and weight tensors along the channel dimension
    input_splits = input.split(1, dim=1)
    weight_splits = weight.split(1, dim=0)

    output_splits = []
    for i in range(channels):
        # Unfold the i-th channel: (bsz, kernel_size*kernel_size, n_windows)
        input_unf = F.unfold(input_splits[i], weight_splits[i].shape[-2:],
                             dilation=dilation, padding=padding, stride=stride)
        # (bsz, n_windows, k*k) @ (k*k, 1) -> (bsz, n_windows, 1) -> (bsz, 1, n_windows)
        out_unf = input_unf.transpose(1, 2).matmul(
            weight_splits[i].view(weight_splits[i].shape[0], -1).t()
        ).transpose(1, 2)
        # If bias is given, add the per-channel bias
        if bias is not None:
            out_unf += bias[i].view(1, -1, 1)
        # Fold the output back to spatial form and collect it
        out_size = (input_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
        output_splits.append(F.fold(out_unf, out_size, (1, 1)))

    # Concatenate the output splits along the channel dimension to get the final output
    return torch.cat(output_splits, dim=1)
```
I have found examples like the one above that use unfold to express a DWC partially as a matmul, but how can I go from this to a single multiplication of two matrices, as described in the paper?
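For reference, my current best guess, based only on the linearity of convolution and not on any code from the paper, is to build an explicit (N, N) matrix per channel by convolving the N one-hot basis images with each kernel, and then batch-multiplying with the flattened input. This reproduces the DWC numerically, but it yields a dense (C, N, N) tensor rather than the compact (3, 196, 196) form in the paper, so I suspect I am missing the grouping/sparsity trick:

```python
import torch
import torch.nn.functional as F

def dwc_as_matmul(x, weight, padding=2):
    """x: (B, C, H, W); weight: (C, 1, k, k) depthwise kernels."""
    B, C, H, W = x.shape
    N = H * W
    # Convolve the N one-hot basis images with every channel's kernel.
    # Row n of resp.view(N, C, N)[:, c, :] is the response to basis image n,
    # so transposing gives M[c] with M[c] @ vec(x_c) == DWC(x_c).
    eye = torch.eye(N).view(N, 1, H, W)
    resp = F.conv2d(eye, weight, padding=padding)   # (N, C, H, W)
    M = resp.view(N, C, N).permute(1, 2, 0)         # (C, N, N)
    y = torch.einsum('cmn,bcn->bcm', M, x.view(B, C, N))
    return y.view(B, C, H, W)

# Sanity check against the regular depthwise convolution
x = torch.randn(2, 8, 14, 14)
w = torch.randn(8, 1, 5, 5)
ref = F.conv2d(x, w, padding=2, groups=8)
print(torch.allclose(dwc_as_matmul(x, w), ref, atol=1e-4))  # True
```

Is this the right direction, and if so, how does the paper compress the (C, N, N) tensor down to 3 matrices?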