Hello everyone! I’ve been working on a customized Conv2d layer that lets me operate on the data immediately after every multiplication and sum (I’m trying to prove that we can quantize a CNN to use fewer bits while keeping the accuracy; it’s for my master’s). My starting point is the code available here, where unfold and fold are used for that.

The code is:

```
# Vanilla convolution via unfold + matmul
def myconv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    batch_size, in_channels, in_h, in_w = input.shape
    out_channels, in_channels, kh, kw = weight.shape
    # output spatial size for the given stride/padding/dilation
    out_h = (in_h + 2 * padding - dilation * (kh - 1) - 1) // stride + 1
    out_w = (in_w + 2 * padding - dilation * (kw - 1) - 1) // stride + 1
    unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=dilation, padding=padding, stride=stride)
    inp_unf = unfold(input)
    w_ = weight.view(weight.size(0), -1).t()
    if bias is None:
        out_unf = inp_unf.transpose(1, 2).matmul(w_).transpose(1, 2)
    else:
        out_unf = (inp_unf.transpose(1, 2).matmul(w_) + bias).transpose(1, 2)
    out = out_unf.view(batch_size, out_channels, out_h, out_w)
    return out
```

The problem is that I need to call a function immediately after every multiplication and sum, so I could not use the `.matmul()` method, since it does both for me. My first attempt was to do the operations independently, calling `myfunc()` after each one of them:

```
# Convolution with explicit multiply and sum, calling myfunc() after each
def myconv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    batch_size, in_channels, in_h, in_w = input.shape
    out_channels, in_channels, kh, kw = weight.shape
    # output spatial size for the given stride/padding/dilation
    out_h = (in_h + 2 * padding - dilation * (kh - 1) - 1) // stride + 1
    out_w = (in_w + 2 * padding - dilation * (kw - 1) - 1) // stride + 1
    unfold = torch.nn.Unfold(kernel_size=(kh, kw), dilation=dilation, padding=padding, stride=stride)
    inp_unf = unfold(input)
    weight_t = weight.view(weight.size(0), -1).t()
    inp_unf_t = inp_unf.transpose(1, 2)
    inp_unf_t = inp_unf_t[:, :, :, None]
    # expand both operands to (B, L, O, C*kh*kw) so mul() and sum() can be split
    inp_unf_t_exp = inp_unf_t.transpose(2, 3).repeat((1, 1, weight_t.t().size(0), 1))
    weight_t_exp = weight_t.t().repeat((1, inp_unf_t.size(1), 1, 1))
    out_unf = myfunc(inp_unf_t_exp.mul(weight_t_exp))
    out_unf = myfunc(torch.sum(out_unf, dim=3).transpose(1, 2))
    if bias is not None:
        out_unf = out_unf + myfunc(bias.view(-1, 1))
    out = out_unf.view(batch_size, out_channels, out_h, out_w)
    return out
```

This approach, however, is **terribly slow** and consumes a **lot of memory**, to the point that a “RuntimeError: CUDA out of memory” error is thrown. Is there an alternative to that approach that would be less memory-greedy?
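For reference, here is one direction I’ve been considering: the two `repeat()` calls materialize full copies of the input patches and the weights, whereas broadcasting produces the same elementwise product without copying either operand, and chunking the output channels bounds the size of the intermediate product tensor. This is only a sketch: `myconv2d_chunked` and the `chunk` parameter are names I made up, and applying `myfunc` per chunk is equivalent to the version above only if `myfunc` is elementwise.

```python
import torch

def myfunc(x):
    # placeholder for the per-operation hook (identity here); assumed elementwise
    return x

def myconv2d_chunked(input, weight, bias=None, stride=1, padding=0,
                     dilation=1, chunk=8):
    batch_size, in_channels, in_h, in_w = input.shape
    out_channels, _, kh, kw = weight.shape
    # output spatial size for the given stride/padding/dilation
    out_h = (in_h + 2 * padding - dilation * (kh - 1) - 1) // stride + 1
    out_w = (in_w + 2 * padding - dilation * (kw - 1) - 1) // stride + 1
    inp_unf = torch.nn.functional.unfold(
        input, (kh, kw), dilation=dilation, padding=padding, stride=stride)
    inp_unf = inp_unf.unsqueeze(1)               # (B, 1, C*kh*kw, L)
    w_flat = weight.view(out_channels, -1)       # (O, C*kh*kw)
    pieces = []
    for w_chunk in w_flat.split(chunk, dim=0):   # (o, C*kh*kw) per chunk
        # broadcasting replaces repeat(): the product tensor is (B, o, C*kh*kw, L),
        # so only `chunk` output channels' worth of products exist at once
        prod = myfunc(inp_unf * w_chunk.unsqueeze(0).unsqueeze(-1))
        pieces.append(myfunc(prod.sum(dim=2)))   # (B, o, L)
    out_unf = torch.cat(pieces, dim=1)           # (B, O, L)
    if bias is not None:
        out_unf = out_unf + bias.view(1, -1, 1)
    return out_unf.view(batch_size, out_channels, out_h, out_w)
```

With `myfunc` as the identity this should match `torch.nn.functional.conv2d` up to floating-point tolerance, and smaller `chunk` values trade speed for a smaller peak allocation.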

Any help is much appreciated!