Conv2d parameters as a matrix

I have a Conv2d layer, and I need to get a matrix from the layer weights that, when multiplied by the input x, gives the same result as applying the layer to x. I have found unfolding-based solutions applied to the input, but in my case I would like to get the matrix for the Conv2d parameters.

In other words, I need a function that computes to_matrix() in the code below. I imagine some flattening of x would be necessary, which is fine.

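import torch
from torch.nn import Conv2d

n_channels = 2
n_out = 3          # number of output channels (example value)
kernel_size = 3
x_size = (4, 5)
batch_size = 1

layer = Conv2d(n_channels, n_out, kernel_size)
w = layer.weight
b = layer.bias

x = torch.rand((batch_size, n_channels) + x_size)

y = layer(x)
# to_matrix is the function I am looking for (x probably needs to be flattened)
my_y = torch.matmul(to_matrix(w, b), x)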


Any help would be appreciated.

From what I understand: suppose you have an input tensor x with shape [C, H, W], and after applying the convolution to x you get a tensor y with shape [C_out, H_out, W_out]. The Conv2d weight has shape [C_out, C, k, k] (equivalently [C_out, C*k*k] once flattened), where k is the kernel size. In this scenario, you need C_out*H_out*W_out*C*k*k multiply-adds (each of the C_out*H_out*W_out outputs uses C*k*k weights) but only C_out*C*k*k parameters.
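To make these counts concrete, here is a small sketch using the sizes from the question (C=2, k=3, H×W=4×5). The number of output channels, C_out=3, is an assumption because the question leaves n_out undefined; stride 1, no padding and no dilation are also assumed.

```python
import torch
from torch.nn import Conv2d

C, H, W, k = 2, 4, 5, 3   # sizes from the question
C_out = 3                 # assumption: n_out is not given in the question

layer = Conv2d(C, C_out, k)
# Conv2d stores its weight as [C_out, C, k, k]; flattening gives [C_out, C*k*k]
assert layer.weight.shape == (C_out, C, k, k)
w_mat = layer.weight.reshape(C_out, -1)

# Output spatial size for stride 1, no padding, no dilation
H_out, W_out = H - k + 1, W - k + 1
y = layer(torch.rand(1, C, H, W))
assert y.shape == (1, C_out, H_out, W_out)

# Each of the C_out*H_out*W_out outputs needs C*k*k multiply-adds
conv_macs = C_out * H_out * W_out * C * k * k
conv_weights = C_out * C * k * k   # the bias adds C_out more parameters

print(tuple(w_mat.shape), (H_out, W_out), conv_macs, conv_weights)
# (3, 18) (2, 3) 324 54
```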

If you want to vectorize the Conv2d operation, i.e. write it as a single matrix product, it becomes equivalent to applying a dense layer with a [C_out*H_out*W_out, C*H*W] matrix. That requires C*H*W*C_out*H_out*W_out multiply-adds, and the matrix has the same number of entries. I am not sure why you would want to do this, but it clearly increases both the computation and the memory significantly.
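One way to see this dense-layer equivalence, and a brute-force way to obtain to_matrix for small inputs, is to use linearity: without the bias, the convolution is a linear map, so column j of its matrix is the response of the convolution to the j-th one-hot input image. The sketch below assumes padding_mode='zeros', a layer with bias, and the question's sizes with C_out=3; conv_to_dense is a hypothetical helper name. Because M is dense, this is only practical for small images.

```python
import torch
import torch.nn.functional as F
from torch.nn import Conv2d

def conv_to_dense(layer, in_shape):
    """Hypothetical helper: returns a dense matrix M and a vector c such that
    layer(x).flatten() == M @ x.flatten() + c for one image x of shape in_shape.
    Assumes padding_mode='zeros' and a layer with bias."""
    C, H, W = in_shape
    n_in = C * H * W
    with torch.no_grad():
        # One one-hot image per input element: basis[j].flatten() is e_j
        basis = torch.eye(n_in).reshape(n_in, C, H, W)
        # Without bias the convolution is linear, so its response to e_j is column j of M
        resp = F.conv2d(basis, layer.weight, None, layer.stride,
                        layer.padding, layer.dilation, layer.groups)
        M = resp.reshape(n_in, -1).T          # [C_out*H_out*W_out, C*H*W]
        # The bias is constant over each output channel's H_out*W_out positions
        c = layer.bias.repeat_interleave(resp.shape[-2] * resp.shape[-1])
    return M, c

# Sizes from the question; C_out = 3 is an assumption (n_out is not given)
layer = Conv2d(2, 3, 3)
x = torch.rand(2, 4, 5)
M, c = conv_to_dense(layer, x.shape)
with torch.no_grad():
    y = layer(x.unsqueeze(0)).flatten()
print(M.shape, M.numel())                     # torch.Size([18, 40]) 720
print(torch.allclose(M @ x.flatten() + c, y, atol=1e-5))
```

The 720 entries match C*H*W*C_out*H_out*W_out = 40*18, while each row holds only C*k*k = 18 of the kernel weights; everything else is zero, which is why a sparse representation is the natural choice.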

Indeed, I do not want to perform the torch.matmul(to_matrix(w, b), x) computation. Instead, I want the matrix to_matrix(w, b) itself, so that I can characterize the matrix created from the layer weights.

I’m not sure if this is what you’re looking for, but it seems related:

And this thread might help:

By the way, using a Toeplitz matrix may not be the most efficient approach. This post provides a more efficient algorithm with fewer calculations:

However, it still needs the operations to be parallelized.
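For intuition about the Toeplitz structure mentioned above: in 1-D, a single-channel cross-correlation (which is what PyTorch's conv layers compute) with a kernel w of length k maps an input of length n to an output of length n-k+1, and its matrix repeats w along each row, shifted one column to the right per row. In 2-D these Toeplitz blocks are nested into a doubly block Toeplitz matrix. This is a minimal sketch with one channel, stride 1 and no padding; toeplitz_1d is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def toeplitz_1d(w, n):
    """Hypothetical helper: matrix T with T @ x == conv1d(x, w) for a
    single-channel input x of length n (stride 1, no padding)."""
    k = w.numel()
    n_out = n - k + 1
    T = torch.zeros(n_out, n, dtype=w.dtype)
    for i in range(n_out):
        T[i, i:i + k] = w   # row i holds the kernel shifted by i columns
    return T

w = torch.tensor([1.0, 2.0, 3.0])
x = torch.arange(6, dtype=torch.float32)
T = toeplitz_1d(w, x.numel())
ref = F.conv1d(x.view(1, 1, -1), w.view(1, 1, -1)).flatten()
print(T)
print(T @ x, ref)   # both are [8., 14., 20., 26.]
```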

After some time I came up with the following function, which computes a sparse CSR matrix from an nn.Conv2d layer. The matrix has one extra column for the bias, so a 1 must be appended to the flattened input. The implementation could be improved by handling the padding properly, but padding the input images works (see the test script below).

import torch
from warnings import warn

def conv2d_to_sparse(input_shape, weight, bias, stride=(1, 1), padding=(0, 0), dilation=(1, 1), device='cpu', verbose=False, warns=True):
    # Returns a CSR matrix A of shape [Cout*Hout*Wout, Cin*Hin*Win + 1] such that
    # A @ cat(x_padded.flatten(), [1]) == conv(x).flatten(); the last column holds the bias.
    if dilation != (1, 1):
        raise RuntimeError('This function does not account for dilation; if you extend it, please send us a PR ;).')
    if padding != (0, 0):
        # the matrix is built for the padded input, so enlarge the spatial dimensions
        input_shape = input_shape[0:1] + torch.Size([x + 2*y for x, y in zip(input_shape[-2:], padding)])
        if warns: warn('Do not forget to pad your input according to the Conv2d padding. Deactivate this warning by passing warns=False as argument.', stacklevel=2)

    Cin, Hin, Win = input_shape
    Cout = weight.shape[0]
    Hk = weight.shape[2]
    Wk = weight.shape[3]
    kernel = weight

    # output size (Conv2d formula, with integer floor division)
    Hout = (Hin - dilation[0]*(Hk - 1) - 1)//stride[0] + 1
    Wout = (Win - dilation[1]*(Wk - 1) - 1)//stride[1] + 1

    shape_out = torch.Size((Cout*Hout*Wout, Cin*Hin*Win + 1))
    nnz_per_row = Cin*Hk*Wk + 1  # kernel entries + bias
    # every row has the same number of nonzeros, so the row pointers are evenly spaced
    crow = torch.arange(shape_out[0] + 1, dtype=torch.int64)*nnz_per_row
    # getting columns
    cols = torch.zeros(Cout*Hout*Wout, nnz_per_row, dtype=torch.int64)
    data = torch.zeros(Cout*Hout*Wout, nnz_per_row, dtype=weight.dtype)
    # flat input indices touched by the kernel at output position (0, 0)
    base_row = torch.zeros(Cin*Hk*Wk, dtype=torch.int64)
    for cin in range(Cin):
        c_shift = cin*(Hin*Win)
        for hk in range(Hk):
            h_shift = hk*Win
            for wk in range(Wk):
                idx = cin*Hk*Wk+hk*Wk+wk
                w_shift = wk
                base_row[idx] = c_shift+h_shift+w_shift
    # each output position reads the same pattern, shifted by the stride
    for cout in range(Cout):
        k = kernel[cout]
        _d = torch.hstack((k.flatten(), bias[cout]))
        for ho in range(Hout):
            h_shift = ho*Win*stride[0]
            for wo in range(Wout):
                w_shift = wo*stride[1]
                idx = cout*Hout*Wout+ho*Wout+wo
                shift = h_shift+w_shift
                cols[idx, :-1] = base_row+shift
                data[idx] = _d

    # add bias as the last column
    cols[:, -1] = Cin*Hin*Win

    cols = cols.flatten()
    data = data.flatten()

    csr_mat = torch.sparse_csr_tensor(crow, cols, data, size=shape_out, device=device)
    return csr_mat

And to test it:

import torch
from models.conv2d_to_sparse import conv2d_to_sparse as c2s
from time import time
from numpy.random import randint as ri
from torch.nn.modules.utils import _reverse_repeat_tuple
from torch.nn.functional import pad

if __name__ == '__main__':
    use_cuda = torch.cuda.is_available()
    cuda_index = max(torch.cuda.device_count() - 2, 0) # second-to-last GPU, or GPU 0 if there is only one
    device = torch.device(f"cuda:{cuda_index}" if use_cuda else "cpu")

    for i in range(30):
        nc = ri(2, 20) # n channels
        kw = ri(2, 10) # kernel width 
        kh = ri(2, 10) # kernel height
        iw = ri(10, 50) # image width
        ih = ri(10, 50) # image height
        ns = 1 # n samples
        cic = nc # conv in channels
        coc = ri(2, 20) # conv out channels
        sh = ri(2, 10) # stride height
        sw = ri(2, 10) # stride width
        ph = ri(2, 10) # padding height
        pw = ri(2, 10) # padding width
        print('cic, coc: ', cic, coc)
        print('kernel h, w: ', kh, kw)
        print('image h, w: ', ih, iw)
        print('stride h, w: ', sh, sw)
        print('padding h, w: ', ph, pw)
        c = torch.nn.Conv2d(cic, coc, (kh, kw), stride=(sh, sw), dilation=(1,1), padding=(ph,pw))
        w = c.weight
        b = c.bias
        x = torch.rand(ns, nc, ih, iw)
        r = c(x).to(device)

        t0 = time()
        # pad input image
        pad_mode = c.padding_mode if c.padding_mode != 'zeros' else 'constant'
        x_pad = pad(x, pad=_reverse_repeat_tuple(c.padding, 2), mode=pad_mode) 
        my_csr = c2s(x[0].shape, w, b, stride=c.stride, padding=c.padding, dilation=c.dilation, device=device)

        u, s, v = torch.svd_lowrank(my_csr, q=300) # U, singular values, V
        #print('singular values: ', s)
        t_curr = time()-t0

        lc = my_csr.to_dense()
        xu = torch.hstack((x_pad.flatten(), torch.ones(1))).to(device)
        ru = lc@xu
        error = torch.norm(r-ru.reshape(r.shape))/torch.norm(r) 
        print('error ru: ', error)
        print('time: ', t_curr) 
        if error > 1e-4: # relative error; float32 round-off should be far below this
            raise RuntimeError('go debug that conv.')
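The nested Python loops in the function above can also be replaced by tensor operations. The sketch below is one possible loop-free variant, not the algorithm from the post linked earlier: it passes an "index image", whose pixel values are their own flat indices, through torch.nn.functional.unfold, which returns for every output position the flat input indices under the kernel, in the same (cin, hk, wk) order the loops use. The name conv2d_to_csr_vectorized is hypothetical, and like the original it assumes an already padded input, dilation 1, groups 1 and a layer with bias.

```python
import torch
import torch.nn.functional as F

def conv2d_to_csr_vectorized(input_shape, weight, bias, stride=(1, 1)):
    """Hypothetical loop-free builder. Returns a CSR matrix A of shape
    [Cout*Hout*Wout, Cin*Hin*Win + 1] with the bias in the last column, so that
    A @ cat(x.flatten(), [1]) == conv2d(x).flatten() for an (already padded) x."""
    Cin, Hin, Win = input_shape
    Cout, _, Hk, Wk = weight.shape
    n_in = Cin * Hin * Win
    # Pixel values are their own flat indices; float64 stores them exactly
    index_img = torch.arange(n_in, dtype=torch.float64).reshape(1, Cin, Hin, Win)
    # unfold -> [Cin*Hk*Wk, L]: flat input indices under the kernel at each of the
    # L = Hout*Wout output positions (rows ordered by cin, hk, wk)
    patches = F.unfold(index_img, (Hk, Wk), stride=stride)[0].round().long()
    L = patches.shape[1]
    # Column indices for one output channel, plus the bias column at index n_in
    bias_col = torch.full((L, 1), n_in, dtype=torch.long)
    cols_one = torch.cat((patches.T, bias_col), dim=1)       # [L, Cin*Hk*Wk + 1]
    cols = cols_one.repeat(Cout, 1)                          # rows ordered (cout, ho, wo)
    # Values: flattened kernel of each output channel followed by its bias,
    # repeated for each of that channel's L output positions
    vals_one = torch.cat((weight.detach().reshape(Cout, -1),
                          bias.detach().reshape(Cout, 1)), dim=1)
    vals = vals_one.repeat_interleave(L, dim=0)              # [Cout*L, Cin*Hk*Wk + 1]
    # Every row has the same number of nonzeros, so row pointers are evenly spaced
    crow = torch.arange(Cout * L + 1, dtype=torch.long) * cols.shape[1]
    return torch.sparse_csr_tensor(crow, cols.flatten(), vals.flatten(),
                                   size=(Cout * L, n_in + 1))

# Usage: compare against the layer itself (no padding, non-square kernel and stride)
layer = torch.nn.Conv2d(2, 3, (3, 2), stride=(2, 1))
x = torch.rand(2, 7, 6)
A = conv2d_to_csr_vectorized(x.shape, layer.weight, layer.bias, stride=layer.stride)
xu = torch.cat((x.flatten(), torch.ones(1)))
with torch.no_grad():
    y = layer(x.unsqueeze(0)).flatten()
print(A.shape, torch.allclose(A.to_dense() @ xu, y, atol=1e-5))
```

Building the index pattern once and tiling it over output channels keeps the work in a handful of tensor calls, which matters for large images where the triple loops become slow.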