After some time I pulled out the following function, which builds a sparse (CSR) matrix equivalent to an nn.Conv2d layer: a convolution is a linear map, so its action on the flattened input can be written as a matrix, with the bias folded in as an extra column that multiplies a constant 1 appended to the input. The implementation could be improved by handling the padding internally, but zero-padding the input by hand works just as well (see the quick check below and the test script at the end).
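To see why the workaround is sound, here is a minimal sketch (shapes are arbitrary) checking that Conv2d with padding gives the same result as zero-padding the input and convolving with padding=0:

import torch
from torch.nn.functional import pad

# Conv2d with padding vs. manual zero-padding followed by an unpadded conv
conv_p = torch.nn.Conv2d(3, 4, (3, 3), padding=(1, 2))
conv_0 = torch.nn.Conv2d(3, 4, (3, 3), padding=(0, 0))
conv_0.weight, conv_0.bias = conv_p.weight, conv_p.bias  # share the same parameters

x = torch.rand(1, 3, 8, 8)
x_pad = pad(x, (2, 2, 1, 1))  # F.pad takes (W_left, W_right, H_top, H_bottom)
print(torch.allclose(conv_p(x), conv_0(x_pad), atol=1e-6))  # True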
import numpy as np
import torch
from warnings import warn

def c2s(input_shape, weight, bias, stride=(1, 1), padding=(0, 0), dilation=(1, 1), device='cpu', verbose=False, warns=True):
    if dilation != (1, 1):
        raise RuntimeError('This function does not account for dilation; if you extend it, please send us a PR ;).')
    if padding != (0, 0):
        # account for padding by enlarging the input shape; the caller must pad the actual input
        input_shape = input_shape[0:1] + torch.Size([x + 2*y for x, y in zip(input_shape[-2:], padding)])
        if warns:
            warn('Do not forget to pad your input according to the Conv2d padding. Deactivate this warning by passing warns=False.', stacklevel=2)
    Cin, Hin, Win = input_shape
    Cout = weight.shape[0]
    Hk = weight.shape[2]
    Wk = weight.shape[3]
    kernel = weight
    # output spatial size, as in the Conv2d documentation
    Hout = int(np.floor((Hin - dilation[0]*(Hk - 1) - 1)/stride[0] + 1))
    Wout = int(np.floor((Win - dilation[1]*(Wk - 1) - 1)/stride[1] + 1))
    # one row per output pixel; one column per input pixel, plus one column for the bias
    shape_out = torch.Size((Cout*Hout*Wout, Cin*Hin*Win + 1))
    # row pointer: every row holds exactly Cin*Hk*Wk kernel taps plus one bias entry
    crow = torch.arange(shape_out[0] + 1, dtype=torch.int32)*(Hk*Wk*Cin + 1)
    nnz = crow[-1]  # total number of nonzeros
    # column indices and values, one row per output pixel
    cols = torch.zeros(Cout*Hout*Wout, Cin*Hk*Wk + 1, dtype=torch.int)
    data = torch.zeros(Cout*Hout*Wout, Cin*Hk*Wk + 1)
    # column indices read by a kernel anchored at the top-left input pixel
    base_row = torch.zeros(Cin*Hk*Wk, dtype=torch.int)
    for cin in range(Cin):
        c_shift = cin*(Hin*Win)
        for hk in range(Hk):
            h_shift = hk*Win
            for wk in range(Wk):
                idx = cin*Hk*Wk + hk*Wk + wk
                base_row[idx] = c_shift + h_shift + wk
    # every output pixel reads the same pattern, shifted according to the stride
    for cout in range(Cout):
        k = kernel[cout]
        _d = torch.hstack((k.flatten(), bias[cout]))
        for ho in range(Hout):
            h_shift = ho*Win*stride[0]
            for wo in range(Wout):
                w_shift = wo*stride[1]
                idx = cout*Hout*Wout + ho*Wout + wo
                cols[idx, :-1] = base_row + h_shift + w_shift
                data[idx] = _d
    # add bias as the last column
    cols[:, -1] = Cin*Hin*Win
    cols = cols.flatten()
    data = data.flatten()
    csr_mat = torch.sparse_csr_tensor(crow, cols, data, size=shape_out, device=device)
    return csr_mat
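For a quick feel of the interface, here is a minimal sketch (assuming the function above is in scope as c2s): the dense form of the matrix, applied to the flattened input with a 1 appended for the bias column, reproduces the convolution.

import torch

conv = torch.nn.Conv2d(1, 2, (3, 3))  # no padding, unit stride
x = torch.rand(1, 5, 5)               # (Cin, Hin, Win), single sample
A = c2s(x.shape, conv.weight, conv.bias, stride=conv.stride)

xu = torch.hstack((x.flatten(), torch.ones(1)))   # append 1 for the bias column
y = (A.to_dense() @ xu).reshape(2, 3, 3)          # (Cout, Hout, Wout)
print(torch.allclose(y, conv(x.unsqueeze(0))[0], atol=1e-5))  # True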
And here is the script I use to test it more thoroughly:
import torch
from models.conv2d_to_sparse import conv2d_to_sparse as c2s
from time import time
from numpy.random import randint as ri
from torch.nn.modules.utils import _reverse_repeat_tuple
from torch.nn.functional import pad

if __name__ == '__main__':
    use_cuda = torch.cuda.is_available()
    cuda_index = max(torch.cuda.device_count() - 2, 0)  # pick a GPU on this machine, fall back to the first one
    device = torch.device(f"cuda:{cuda_index}" if use_cuda else "cpu")

    for i in range(30):
        nc = ri(2, 20)   # n channels
        kw = ri(2, 10)   # kernel width
        kh = ri(2, 10)   # kernel height
        iw = ri(10, 50)  # image width
        ih = ri(10, 50)  # image height
        ns = 1           # n samples
        cic = nc         # conv in channels
        coc = ri(2, 20)  # conv out channels
        sh = ri(2, 10)   # stride height
        sw = ri(2, 10)   # stride width
        ph = ri(2, 10)   # padding height
        pw = ri(2, 10)   # padding width
        print('\n-------------------------')
        print('cic, coc: ', cic, coc)
        print('kernel h, w: ', kh, kw)
        print('image h, w: ', ih, iw)
        print('stride h, w: ', sh, sw)
        print('padding h, w: ', ph, pw)

        c = torch.nn.Conv2d(cic, coc, (kh, kw), stride=(sh, sw), dilation=(1, 1), padding=(ph, pw))
        w = c.weight
        b = c.bias
        x = torch.rand(ns, nc, ih, iw)
        r = c(x).to(device)

        t0 = time()
        # pad the input image by hand, exactly as Conv2d would
        pad_mode = c.padding_mode if c.padding_mode != 'zeros' else 'constant'
        x_pad = pad(x, pad=_reverse_repeat_tuple(c.padding, 2), mode=pad_mode)
        my_csr = c2s(x[0].shape, w, b, stride=c.stride, padding=c.padding, dilation=c.dilation, device=device)
        #'''
        print('SVDing')
        u, s, v = torch.svd_lowrank(my_csr, q=300)
        #print('s:', s)
        print('SVDone')
        #'''
        t_curr = time() - t0

        # compare the sparse matrix applied to the flattened padded input against Conv2d
        lc = my_csr.to_dense()
        xu = torch.hstack((x_pad.flatten(), torch.ones(1))).to(device)
        ru = lc @ xu
        error = torch.norm(r - ru.reshape(r.shape))/torch.norm(r)
        print('error ru: ', error)
        print('time: ', t_curr)
        if error > 1.0:  # very loose sanity threshold on the relative error
            raise RuntimeError('go debug that conv.')
        print('-------------------------\n')