DWC to Torch Linear or MatMul

Hello,

Recently, I read “Flatten Transformer: Vision Transformer using Focused Linear Attention” (arXiv:2308.00442) and became curious about the method mentioned in the paper for transforming depthwise convolution (DWC) into a simple matmul operation. According to the paper, a DWC with a (192, 14, 14) input and a (5, 5) kernel can be replaced by a matrix multiplication of the form (3, 196, 196) @ (192, 14, 14). However, I am having difficulty implementing this. Here is what I have so far:
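Since convolution is linear in its input, my understanding is that each channel's DWC can be expressed as an (H·W, H·W) matrix acting on the flattened input. Below is a small sketch of that idea for a single channel (my own construction, not code from the paper; `dwc_as_matrix` is a hypothetical helper name), building the matrix column by column from one-hot basis images and checking it against F.conv2d:

```python
import torch
import torch.nn.functional as F

def dwc_as_matrix(weight, size, padding):
    """Build M of shape (size*size, size*size) such that
    M @ x.flatten() == conv2d(x, weight, padding=padding).flatten()
    for a single-channel (size, size) input x."""
    n = size * size
    # Each row of the identity, reshaped, is a one-hot basis image
    basis = torch.eye(n).view(n, 1, size, size)
    # Convolving basis image j gives column j of M (convolution is linear)
    cols = F.conv2d(basis, weight.view(1, 1, *weight.shape), padding=padding)
    return cols.view(n, n).t()

torch.manual_seed(0)
w = torch.randn(3, 3)          # toy (3, 3) kernel for one channel
x = torch.randn(5, 5)          # toy (5, 5) single-channel input
M = dwc_as_matrix(w, 5, 1)     # (25, 25) matrix
direct = F.conv2d(x.view(1, 1, 5, 5), w.view(1, 1, 3, 3), padding=1).flatten()
print(torch.allclose(M @ x.flatten(), direct, atol=1e-5))
```

This gives one dense matrix per channel, though, whereas the (3, 196, 196) shape in the paper seems to imply the matrices are shared across groups of channels.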

import torch
import torch.nn.functional as F


def depthwise_conv2d_matmul(input, weight, bias=None, stride=1, padding=2, dilation=1):
    bsz, channels, h, w = input.shape
    k_channels, _, k_h, k_w = weight.shape

    assert h == w, 'Input tensor must be square'
    assert channels == k_channels, 'Number of input channels and kernel channels must match'
    assert k_h == k_w, 'Kernel must be square'
    input_size = h
    kernel_size = k_h

    # Split the input tensor and weight tensor along the channel dimension
    input_splits = input.split(1, dim=1)
    weight_splits = weight.split(1, dim=0)
    output_splits = []
    for i in range(channels):
        # Unfold the input, input_unf.shape:                                torch.Size([bsz, kernel_size*kernel_size, window_size])
        input_unf = F.unfold(input_splits[i], weight_splits[i].shape[-2:], dilation=dilation, padding=padding,
                             stride=stride)
        # Perform depth-wise convolution
        # input_unf.transpose(1, 2) shape:                                  torch.Size([bsz, window_size, kernel_size*kernel_size])
        # weight_splits[i].view(weight_splits[i].shape[0], -1).t() shape:   torch.Size([kernel_size*kernel_size, 1])
        # out_unf.shape:                                                    torch.Size([bsz, 1, window_size])

        out_unf = input_unf.transpose(1, 2).matmul(
            weight_splits[i].view(weight_splits[i].shape[0], -1).t()
        ).transpose(1, 2)
        # If bias is not None, add bias
        if bias is not None:
            out_unf += bias[i].view(1, -1, 1)

        # Fold the output tensor and add it to the list of output splits
        # Fold the per-channel output back to (bsz, 1, out_size, out_size)
        out_size = (input_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1
        combined_out = F.fold(out_unf, (out_size, out_size), (1, 1))

        output_splits.append(combined_out)

    # Concatenate the output splits along the channel dimension to get the final output
    return torch.cat(output_splits, dim=1)
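For reference, the per-channel loop above can also be collapsed into a single unfold followed by one einsum over all channels (a vectorized variant I tried with toy shapes; it matches F.conv2d, but it is still not the two-matrix form the paper describes):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 192, 14, 14)  # (bsz, C, H, W)
w = torch.randn(192, 1, 5, 5)    # depthwise kernels, one per channel

# Unfold once over all channels: (bsz, C*k*k, L) with L = 14*14 = 196
unf = F.unfold(x, (5, 5), padding=2)
bsz, _, L = unf.shape
# unfold orders dim 1 as (channel, kernel position), so this view is valid
unf = unf.view(bsz, 192, 25, L)
# Contract the kernel dimension per channel in one batched operation
out = torch.einsum('bckl,ck->bcl', unf, w.view(192, 25)).view(bsz, 192, 14, 14)

ref = F.conv2d(x, w, padding=2, groups=192)
print(torch.allclose(out, ref, atol=1e-4))
```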

I have found examples like this that use unfold to partially convert DWC into a matmul (still looping over channels), but how can I rewrite it as a single multiplication of two matrices, as described in the paper?