Optimising memory allocation in a custom 2D convolution

Hello, I’m a student studying low-power memory circuits and in-memory computing. I’m trying to move the multiply-accumulate (MAC) operations into an SRAM array so that most or all of a convolution can be done inside the memory array.
Since computation inside SRAM cannot handle full-precision arithmetic the way a GPU does, I’m trying to reduce the bit precision of the MAC results as far as possible without suffering severe accuracy loss.
I want to verify this in PyTorch.
However, conv2d does not seem to support this, so I’m building the convolution from scratch.
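As a rough sketch of what reducing the precision of the MAC results could look like in PyTorch, the functions below split a dot product into groups of 16 products (the same 16-wide grouping the forward pass further down uses), round and saturate every partial sum to a signed fixed-point format, and only then add the groups together. The fixed-point format (`n_bits`, `frac_bits`) and the round-then-clamp rule are assumptions for illustration, not a model of any particular SRAM macro; `quantize_psum` and `grouped_dot` are hypothetical names.

```python
import torch


def quantize_psum(x: torch.Tensor, n_bits: int = 8, frac_bits: int = 0) -> torch.Tensor:
    """Round to a signed fixed-point grid and saturate (assumed truncation model)."""
    scale = 2.0 ** frac_bits
    q_min = -(2 ** (n_bits - 1))
    q_max = 2 ** (n_bits - 1) - 1
    # Round to the nearest representable step, then clip to the n_bits range
    return torch.clamp(torch.round(x * scale), q_min, q_max) / scale


def grouped_dot(a: torch.Tensor, b: torch.Tensor, group: int = 16,
                n_bits: int = 8, frac_bits: int = 0) -> torch.Tensor:
    """Dot product along the last dim, quantizing each partial sum of `group` products."""
    prods = a * b
    assert prods.shape[-1] % group == 0, "length must be a multiple of the group size"
    prods = prods.reshape(*prods.shape[:-1], -1, group)   # (..., G, group)
    psum = prods.sum(dim=-1)                              # one partial sum per group
    return quantize_psum(psum, n_bits, frac_bits).sum(dim=-1)  # accumulate quantized partials


if __name__ == "__main__":
    a = torch.ones(32)
    b = torch.full((32,), 10.0)
    # Each group of 16 sums to 160, which saturates to 127 with 8 signed bits
    print(grouped_dot(a, b))  # tensor(254.)
```

In the custom forward pass, the equivalent step would sit between the two `torch.sum` calls, where the per-group partial sums exist before they are added together.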
I’m using a customised mini-VGG: 128C3-128C3-MP2-256C3-256C3-MP2-512C3-512C3-MP2-1024FC-1024FC-10FC
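For readers unfamiliar with the shorthand, here is one way to spell out that architecture with standard PyTorch layers. It assumes 3×32×32 inputs (CIFAR-10-like, given the 10-way output), padding 1 so that each 3×3 convolution keeps the spatial size, and ReLU activations; the input size and the activation/normalization choices are not stated in the post, so treat them as assumptions. `make_mini_vgg` is a hypothetical helper, and a custom convolution would take the place of the `nn.Conv2d` layers.

```python
import torch
from torch import nn


def make_mini_vgg(in_channels: int = 3, num_classes: int = 10) -> nn.Sequential:
    """128C3-128C3-MP2-256C3-256C3-MP2-512C3-512C3-MP2-1024FC-1024FC-10FC (assumed 32x32 input)."""
    layers = []
    c = in_channels
    for width in (128, 256, 512):
        for _ in range(2):
            # "C3": 3x3 convolution; padding=1 keeps H and W unchanged
            layers += [nn.Conv2d(c, width, kernel_size=3, padding=1), nn.ReLU(inplace=True)]
            c = width
        layers.append(nn.MaxPool2d(2))  # "MP2": halves H and W
    layers += [
        nn.Flatten(),
        nn.Linear(512 * 4 * 4, 1024), nn.ReLU(inplace=True),  # 32 -> 16 -> 8 -> 4 after three pools
        nn.Linear(1024, 1024), nn.ReLU(inplace=True),
        nn.Linear(1024, num_classes),
    ]
    return nn.Sequential(*layers)


if __name__ == "__main__":
    model = make_mini_vgg()
    out = model(torch.randn(2, 3, 32, 32))
    print(out.shape)  # torch.Size([2, 10])
```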
I’m still working on the code, but I’m stuck on a “CUDA out of memory” error.
I understand why I’m getting the error, but I can’t come up with a solution.
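To put numbers on the out-of-memory error: the forward pass below broadcasts `input_unfold` against the weights into a tensor `mult` of shape (C_out, B, L, G, 16), i.e. C_out × B × L × C_in·k² elements, before anything is summed. The helper below (hypothetical name `mult_bytes`) evaluates that count for the second 128C3 layer, assuming float32, 32×32 feature maps with padding 1 (so L = 1024 output positions) and an example batch of 128; the post does not state the actual batch size.

```python
def mult_bytes(batch: int, c_in: int, c_out: int, out_hw: int,
               k: int = 3, bytes_per_el: int = 4) -> int:
    """Size in bytes of the (C_out, B, L, C_in*k*k) product tensor built before the sums."""
    positions = out_hw * out_hw   # L: number of sliding-window positions
    reduction = c_in * k * k      # products per output element (= G * 16)
    return c_out * batch * positions * reduction * bytes_per_el


if __name__ == "__main__":
    # 128 -> 128 channels on 32x32 maps, batch 128 (assumed)
    n = mult_bytes(batch=128, c_in=128, c_out=128, out_hw=32)
    print(f"{n / 2**30:.1f} GiB")  # 72.0 GiB
```

That is about 0.56 GiB per sample for this one intermediate tensor, before autograd keeps any other activations alive.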
Is there any way I can optimise the code without changing the batch size or the network?
Also, if anyone knows of a custom 2D convolution implementation I could refer to, I would appreciate it very much.
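A common starting point for a hand-written convolution is the im2col formulation: `torch.nn.functional.unfold` extracts every k×k patch as a column, and one matrix multiply with the flattened weights produces all output pixels without storing the individual products. The sketch below (hypothetical name `conv2d_unfold`; stride 1, no dilation, no groups) checks itself against `torch.nn.functional.conv2d`.

```python
import torch
import torch.nn.functional as F


def conv2d_unfold(x: torch.Tensor, weight: torch.Tensor, bias=None, padding: int = 0) -> torch.Tensor:
    """im2col convolution: stride 1, dilation 1, groups 1."""
    b, _, h, w = x.shape
    c_out, _, kh, kw = weight.shape
    cols = F.unfold(x, (kh, kw), padding=padding)   # (B, C_in*kh*kw, L)
    out = weight.reshape(c_out, -1) @ cols          # (C_out, K) @ (B, K, L) -> (B, C_out, L)
    h_out = h + 2 * padding - kh + 1
    w_out = w + 2 * padding - kw + 1
    out = out.view(b, c_out, h_out, w_out)
    if bias is not None:
        out = out + bias.view(1, -1, 1, 1)
    return out


if __name__ == "__main__":
    x = torch.randn(4, 8, 16, 16)
    wgt = torch.randn(12, 8, 3, 3)
    bias = torch.randn(12)
    ref = F.conv2d(x, wgt, bias, padding=1)
    print(torch.allclose(conv2d_unfold(x, wgt, bias, padding=1), ref, atol=1e-4))  # True
```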
Thank you for reading my question!
(Please forgive me if my writing and code are hard to read. :’( I’m completely new to Python/PyTorch and know only the basics of neural networks.)

import torch
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
from torch.autograd import Function

class Conv2DFunctionCustom(Function):
    # Only forward() is defined so far, so a backward pass through this Function will fail.
    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        # stride, dilation and groups are stored but not used: this assumes stride 1, no dilation,
        # groups 1, square inputs and "same" padding (e.g. a 3x3 kernel with padding 1).
        ctx.save_for_backward(input, weight, bias)
        ctx.stride, ctx.padding, ctx.dilation, ctx.groups = stride, padding, dilation, groups

        # padding may be an int or an (ph, pw) tuple as stored by nn.Conv2d
        pad = padding if isinstance(padding, int) else padding[0]
        input_size = input.shape[2]                  # input: (B, C_in, H, W) with H == W
        ochannel, _, kernel_size, _ = weight.shape   # weight: (C_out, C_in, k, k)

        inp_pad = nn.ZeroPad2d(pad)(input) if pad else input

        # input_unfold: (B, C_in*k*k, L), L = number of output positions
        input_unfold = torch.nn.functional.unfold(inp_pad, (kernel_size, kernel_size))

        # Split the C_in*k*k axis into G groups of 16 products, one partial sum per group.
        # unfold(2, 16, 16) silently drops any remainder, so C_in*k*k must be a multiple of 16.
        # mult: (C_out, B, L, G, 16), by far the largest allocation and the likely OOM culprit.
        mult = input_unfold.transpose(1, 2).unfold(2, 16, 16)[None, :] * weight.reshape(ochannel, -1).unfold(1, 16, 16)[:, None, None]
        psum = torch.sum(mult, dim=4)   # partial sums of 16 products: (C_out, B, L, G)
        psum = torch.sum(psum, dim=3)   # complete dot products: (C_out, B, L)

        # Reshape L back to (H, W) and move the batch axis first: (B, C_out, H, W)
        out = torch.nn.functional.fold(psum, (input_size, input_size), (1, 1)).transpose(0, 1)
        if bias is not None:
            out = out + bias.view(1, -1, 1, 1)

        return out
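One way to keep the 16-product partial sums (where the reduced-precision MAC would act) without materialising the five-dimensional `mult` tensor is to let `torch.einsum` multiply and reduce within each 16-wide group, so only the partial sums (C_out × B × L × G elements, 16× fewer than `mult`) are stored, and to process the batch in chunks so even that stays bounded. This is a sketch under assumptions: stride 1, a square kernel, integer padding, `chunk` as a hypothetical tuning knob, `psum_fn` as a hypothetical hook for whatever truncation model is being tested, and the reduction axis zero-padded to a multiple of 16 (relevant for a 3-channel first layer, where C_in·k² = 27). It is written as a plain function rather than an autograd `Function`.

```python
import torch
import torch.nn.functional as F


def grouped_conv2d(x, weight, bias=None, padding=1, group=16, chunk=16, psum_fn=None):
    """Conv2d (stride 1) whose dot products are accumulated in partial sums of `group` products.

    psum_fn: optional hook applied to every partial sum (hypothetical, e.g. a fixed-point truncation).
    """
    b, _, h, w = x.shape
    c_out, _, k, _ = weight.shape
    h_out, w_out = h + 2 * padding - k + 1, w + 2 * padding - k + 1

    w_flat = weight.reshape(c_out, -1)                         # (C_out, K), K = C_in*k*k
    pad_k = (-w_flat.shape[1]) % group                         # zero-pad K up to a multiple of group
    w_grp = F.pad(w_flat, (0, pad_k)).reshape(c_out, -1, group)  # (C_out, G, group)

    outs = []
    for xb in x.split(chunk):                                  # bound memory with batch chunks
        cols = F.unfold(xb, (k, k), padding=padding)           # (b, K, L)
        cols = F.pad(cols.transpose(1, 2), (0, pad_k))         # (b, L, K + pad_k)
        cols = cols.reshape(cols.shape[0], cols.shape[1], -1, group)  # (b, L, G, group)
        # Multiply and sum within each group without storing the individual products
        psum = torch.einsum("blgi,ogi->bolg", cols, w_grp)     # (b, C_out, L, G)
        if psum_fn is not None:
            psum = psum_fn(psum)
        outs.append(psum.sum(dim=-1))                          # (b, C_out, L)

    out = torch.cat(outs).reshape(b, c_out, h_out, w_out)
    if bias is not None:
        out = out + bias.view(1, -1, 1, 1)
    return out


if __name__ == "__main__":
    x = torch.randn(8, 3, 32, 32)   # 3 input channels: K = 27 is not a multiple of 16
    wgt = torch.randn(128, 3, 3, 3)
    ref = F.conv2d(x, wgt, padding=1)
    print(torch.allclose(grouped_conv2d(x, wgt, chunk=4), ref, atol=1e-4))  # True
```

`psum_fn` can be any elementwise function, for example `lambda p: torch.clamp(torch.round(p), -128, 127)` for an assumed 8-bit signed partial sum. For inference-only accuracy checks, running it under `torch.no_grad()` stops autograd from keeping every chunk's partial sums alive for a backward pass.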

Hi,

I think this post was trying to do the same thing as you, and if I’m not mistaken the same answer applies to your code: Make Custom Conv2d Layer efficient (wrt speed and memory)

Thank you, the post you linked helped a lot. :slight_smile: