Custom efficient Conv2d with intermediate calls to functions

Hello everyone! I’ve been working on a customized Conv2d layer that lets me operate on the data immediately after every multiplication and every sum (I’m trying to prove that we can quantize a CNN to use fewer bits while keeping the accuracy; it’s for my master’s thesis). My starting point is the code available here, where unfold and fold are used for that.
The code is:

# Vanilla convolution via unfold + matmul
import torch

def myconv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    # groups != 1 is not handled here
    batch_size, in_channels, in_h, in_w = input.shape
    out_channels, in_channels, kh, kw = weight.shape

    # Output spatial size (assuming int stride/padding/dilation)
    out_h = (in_h + 2 * padding - dilation * (kh - 1) - 1) // stride + 1
    out_w = (in_w + 2 * padding - dilation * (kw - 1) - 1) // stride + 1

    # Extract sliding patches: (B, C*kh*kw, L) with L = out_h * out_w
    unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=dilation, padding=padding, stride=stride)
    inp_unf = unfold(input)

    # Flatten the kernels: (C*kh*kw, out_channels)
    w_ = weight.view(weight.size(0), -1).t()

    if bias is None:
        out_unf = inp_unf.transpose(1, 2).matmul(w_).transpose(1, 2)
    else:
        out_unf = (inp_unf.transpose(1, 2).matmul(w_) + bias).transpose(1, 2)
    out = out_unf.view(batch_size, out_channels, out_h, out_w)
    return out
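
As a quick sanity check, this version matches the built-in convolution on random data (small toy shapes, float32, so an allclose tolerance is needed):

# Sanity check against torch.nn.functional.conv2d
x = torch.randn(2, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
b = torch.randn(4)
ref = torch.nn.functional.conv2d(x, w, bias=b, stride=1, padding=1)
out = myconv2d(x, w, bias=b, stride=1, padding=1)
print(torch.allclose(ref, out, atol=1e-5))  # should print True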

The problem is that I need to call a function immediately after every multiplication and every sum, so I cannot use the .matmul() method, since it does both in one go. My first attempt was to do the two operations separately, calling myfunc() after each of them:

# Convolution with explicit multiply and sum, so myfunc() can be
# called right after each operation
def myconv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    batch_size, in_channels, in_h, in_w = input.shape
    out_channels, in_channels, kh, kw = weight.shape

    # Output spatial size (assuming int stride/padding/dilation)
    out_h = (in_h + 2 * padding - dilation * (kh - 1) - 1) // stride + 1
    out_w = (in_w + 2 * padding - dilation * (kw - 1) - 1) // stride + 1

    unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=dilation, padding=padding, stride=stride)
    inp_unf = unfold(input)                                 # (B, C*kh*kw, L)

    weight_t = weight.view(weight.size(0), -1).t()          # (C*kh*kw, out_channels)
    inp_unf_t = inp_unf.transpose(1, 2)                     # (B, L, C*kh*kw)
    inp_unf_t = inp_unf_t[:, :, :, None]                    # (B, L, C*kh*kw, 1)
    # Expand one operand per (patch, output channel) pair -- this is where
    # the memory blows up: both tensors are (B, L, out_channels, C*kh*kw)
    inp_unf_t_exp = inp_unf_t.transpose(2, 3).repeat((1, 1, weight_t.t().size(0), 1))
    weight_t_exp = weight_t.t().repeat((1, inp_unf_t.size(1), 1, 1))
    out_unf = myfunc(inp_unf_t_exp.mul(weight_t_exp))       # after every multiplication
    out_unf = myfunc(torch.sum(out_unf, dim=3).transpose(1, 2))  # after every sum
    if bias is not None:
        out_unf = out_unf + myfunc(bias.view(-1, 1))

    out = out_unf.view(batch_size, out_channels, out_h, out_w)
    return out
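
Just for testing, with myfunc as the identity this version reproduces the vanilla result on random data:

# With an identity myfunc, both versions should agree
def myfunc(t):
    return t  # placeholder; the real myfunc does the quantization step

x = torch.randn(2, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
ref = torch.nn.functional.conv2d(x, w, stride=1, padding=1)
out = myconv2d(x, w, stride=1, padding=1)
print(torch.allclose(ref, out, atol=1e-5))  # True for the identity myfunc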

This approach, however, is terribly slow and consumes a lot of memory, to the point that a “RuntimeError: CUDA out of memory” is thrown. Is there an alternative approach that would be less memory-hungry?
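
For what it’s worth, one direction I’ve been sketching (untested beyond toy inputs, and assuming myfunc is elementwise) is to process the unfolded patches in chunks along the spatial dimension, so only a slice of the huge (B, L, out_channels, C*kh*kw) tensor is alive at a time; chunk_size here is just a knob I made up:

# Hypothetical chunked variant: only a slice of the expanded tensor
# is materialized at once; relies on broadcasting instead of repeat()
def myconv2d_chunked(input, weight, bias=None, stride=1, padding=0,
                     dilation=1, groups=1, chunk_size=1024):
    batch_size, in_channels, in_h, in_w = input.shape
    out_channels, _, kh, kw = weight.shape
    out_h = (in_h + 2 * padding - dilation * (kh - 1) - 1) // stride + 1
    out_w = (in_w + 2 * padding - dilation * (kw - 1) - 1) // stride + 1

    unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=dilation,
                             padding=padding, stride=stride)
    inp_unf = unfold(input).transpose(1, 2)      # (B, L, C*kh*kw)
    w_flat = weight.view(out_channels, -1)       # (out_channels, C*kh*kw)

    chunks = []
    for chunk in inp_unf.split(chunk_size, dim=1):   # slices of l <= chunk_size patches
        # (B, l, 1, C*kh*kw) * (out_channels, C*kh*kw) broadcasts to
        # (B, l, out_channels, C*kh*kw) -- only l rows are live at once
        prod = myfunc(chunk[:, :, None, :] * w_flat)  # after every multiplication
        chunks.append(myfunc(prod.sum(dim=3)))        # after every sum: (B, l, out_channels)
    out_unf = torch.cat(chunks, dim=1).transpose(1, 2)  # (B, out_channels, L)
    if bias is not None:
        out_unf = out_unf + myfunc(bias.view(-1, 1))
    return out_unf.view(batch_size, out_channels, out_h, out_w)

This caps the peak memory at roughly batch_size * chunk_size * out_channels * C*kh*kw elements, but it still feels clumsy, so I’d love to hear of a cleaner way.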

Any help is much appreciated!