Hello, I’m a student studying low-power memory circuits and in-memory computing, and I’m trying to bring the multiply-accumulate (MAC) operation into the SRAM array so that most or all of the convolution can be done inside the memory array.
Since computation inside SRAM cannot handle full-precision operations the way a GPU can, I’m trying to reduce the bit precision of the MAC results to some point without suffering severe accuracy loss.
I’m trying to verify this in PyTorch.
However, conv2d does not seem to support this, so I’m building a custom convolution from the ground up.
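For reference, a minimal sketch of a convolution built from `torch.nn.functional.unfold` (im2col) and a matrix multiply, checked against `F.conv2d`. The function name `conv2d_unfold` and the small test shapes are hypothetical and chosen only for illustration; dilation and groups are not handled.

```python
import torch
import torch.nn.functional as F

def conv2d_unfold(x, weight, bias=None, stride=1, padding=0):
    """Convolution as im2col (F.unfold) followed by a matrix multiply."""
    batch, _, height, width = x.shape
    out_ch, _, kh, kw = weight.shape
    # cols: (batch, in_ch * kh * kw, L), one column per output position
    cols = F.unfold(x, (kh, kw), padding=padding, stride=stride)
    # (out_ch, K) @ (batch, K, L) broadcasts to (batch, out_ch, L)
    out = weight.reshape(out_ch, -1) @ cols
    if bias is not None:
        out = out + bias.view(1, -1, 1)
    out_h = (height + 2 * padding - kh) // stride + 1
    out_w = (width + 2 * padding - kw) // stride + 1
    return out.reshape(batch, out_ch, out_h, out_w)

# Small illustrative shapes (assumed, not the network from this post)
torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8)
w = torch.randn(4, 3, 3, 3)
b = torch.randn(4)
ref = F.conv2d(x, w, b, stride=1, padding=1)
out = conv2d_unfold(x, w, b, stride=1, padding=1)
max_abs_diff = (out - ref).abs().max().item()
```

Because every MAC term appears as an explicit row of `cols`, this form leaves room to split the sum into partial sums before accumulating them.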
I’m using a customised mini-VGG: 128C3-128C3-MP2-256C3-256C3-MP2-512C3-512C3-MP2-1024FC-1024FC-10FC
I’m still working on the code, but I’m stuck on a “CUDA out of memory” error.
I understand why the error occurs, but I can’t come up with a solution.
Is there any way I can optimise the code without changing the batch size or network?
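One direction that might help, sketched under assumptions: rather than materialising every individual product at once, the 16-term groups used in the code below could be processed one group at a time, with each partial sum reduced in precision and then accumulated. The names `quantize_psum` and `conv2d_grouped_psum`, and the parameters `n_bits` and `max_abs`, are hypothetical; the uniform clamp-and-round quantizer is only a stand-in for whatever precision model the SRAM array actually needs, and stride, dilation and bias are omitted.

```python
import torch
import torch.nn.functional as F

def quantize_psum(psum, n_bits=8, max_abs=64.0):
    """Hypothetical quantizer: clamp to [-max_abs, max_abs], round to a uniform grid."""
    step = 2 * max_abs / (2 ** n_bits - 1)
    return torch.round(psum.clamp(-max_abs, max_abs) / step) * step

def conv2d_grouped_psum(x, weight, padding=0, group=16, quantize=None):
    """Stride-1 convolution accumulated from partial sums of `group` MAC terms."""
    batch, _, height, width = x.shape
    out_ch, _, kh, kw = weight.shape
    cols = F.unfold(x, (kh, kw), padding=padding)   # (batch, K, L)
    w_flat = weight.reshape(out_ch, -1)             # (out_ch, K)
    out = None
    for start in range(0, w_flat.shape[1], group):
        # One partial sum: (out_ch, g) @ (batch, g, L) -> (batch, out_ch, L)
        psum = w_flat[:, start:start + group] @ cols[:, start:start + group, :]
        if quantize is not None:
            psum = quantize(psum)
        out = psum if out is None else out + psum
    out_h = height + 2 * padding - kh + 1
    out_w = width + 2 * padding - kw + 1
    return out.reshape(batch, out_ch, out_h, out_w)

# Small illustrative shapes (assumed); K = 5 * 3 * 3 = 45, so the last group has 13 terms
torch.manual_seed(0)
x = torch.randn(2, 5, 8, 8)
w = torch.randn(4, 5, 3, 3)
ref = F.conv2d(x, w, padding=1)
exact = conv2d_grouped_psum(x, w, padding=1)
quantized = conv2d_grouped_psum(x, w, padding=1, quantize=quantize_psum)
```

With this loop the largest temporary has shape (batch, out_channels, L) instead of one entry per individual product. Note that `torch.round` has zero gradient almost everywhere, so this covers the forward pass only; training through it would need a custom backward.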
Also, if anyone knows of a custom 2D convolution implementation that I can refer to, I would appreciate it very much.
Thank you for reading my question!
(Please forgive me if my writing and code are hard to read. :’( I’m completely new to Python/PyTorch and know only the basics of neural networks.)
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
from torch.autograd import Function
import pdb
class Conv2DFunctionCustom(Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
        ctx.save_for_backward(input, weight, bias)
        ctx.stride, ctx.padding, ctx.dilation, ctx.groups = stride, padding, dilation, groups
        # padding is indexed here, so it must be passed as a tuple or list, e.g. (1, 1)
        zeropad = nn.ZeroPad2d(ctx.padding[0])
        batch_size = len(input)
        input_channel = len(input[0])
        input_size = len(input[0][0])
        ochannel = len(weight)
        ichannel = len(weight[0])
        kernel_size = len(weight[0][0])
        if padding:
            inp_pad = zeropad(input)
        else:
            inp_pad = input
        input_un_tensor = inp_pad.unfold(1,input_channel,input_channel)
        input_reshape = input_un_tensor.transpose(1,4).reshape(len(inp_pad),len(inp_pad[0]),len(inp_pad[0][0]),len(inp_pad[0][0]))
        weight_un_tensor = weight.unfold(1,ichannel,ichannel)
        weight_reshape = weight_un_tensor.transpose(1,4).reshape(len(weight),len(weight[0]),len(weight[0][0]),len(weight[0][0]))
        # input_unfold: (batch, K, L) with K = in_channels * kernel_size**2 and L output positions
        input_unfold = torch.nn.functional.unfold(input_reshape, (kernel_size, kernel_size))
        print('unfoldinput:',input_unfold.size())
        # Broadcast product of shape (out_channels, batch, L, K/16, 16): one entry per MAC term
        mult = input_unfold.transpose(1, 2).unfold(2,16,16)[None,:] * weight_reshape.view(weight_reshape.size(0), -1).unfold(1,16,16)[:,None,None]
        # Sum within each group of 16 terms, then across the groups
        psum = torch.sum(mult, dim=4)
        psum = torch.sum(psum, dim=3)
        # fold with a 1x1 kernel reshapes L back to (input_size, input_size)
        out = torch.nn.functional.fold(psum, (input_size, input_size), (1, 1)).transpose(0,1)
        return out
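A rough size estimate for the broadcast product `mult` in the code above, which holds one element per (output channel, sample, output position, MAC term). The helper name `broadcast_product_bytes` is hypothetical, and the 32x32 feature map and batch size of 32 are assumptions, since neither is stated in this post.

```python
def broadcast_product_bytes(batch, in_ch, out_ch, kernel, out_positions, bytes_per_elem=4):
    """Bytes held by a (out_ch, batch, L, K/16, 16) tensor, assuming K is divisible by 16."""
    k = in_ch * kernel * kernel  # MAC terms per output value
    return out_ch * batch * out_positions * k * bytes_per_elem

# Assumed values (not from the post): second 128C3 layer, 32x32 output, batch 32, float32
size = broadcast_product_bytes(batch=32, in_ch=128, out_ch=128, kernel=3, out_positions=32 * 32)
gib = size / 2**30
print(f"{gib:.1f} GiB")  # prints "18.0 GiB"
```

Under these assumptions a single layer's intermediate already exceeds the memory of most GPUs, before any other activations are counted.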