Hello
I’ve been trying to build a custom channel-wise convolution operation with multiplication and sum operation so that I can alter the bit-precision during the convolution process.
(I’ve checked several forum articles and all of them told me I should build it with unfold and fold functions.)
But when I replace the torch.nn.conv2d function with my code
There is no error but accuracy seems to degrade a lot
I’m completely new to python and pytorch, and only know basics about neural network so I can’t find what mistakes I made.
Can anyone please help? I would really appreciate any help.
Also, if there is any tips to save gpu usage I would also appreciate it very much, since my code seems to consume tons of memory so I had to cut the batch size and it’s taking forever to see the results.
class Conv2DFunctionCustom(Function):
@staticmethod
def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
ctx.save_for_backward(input, weight, bias)
ctx.stride, ctx.padding, ctx.dilation, ctx.groups = stride, padding, dilation, groups
zeropad = nn.ZeroPad2d(ctx.padding[0])
batch_size = len(input)
input_channel = len(input[0])
input_size = len(input[0][0])
ochannel = len(weight)
ichannel = len(weight[0])
kernel_size = len(weight[0][0])
if padding:
inp_pad = zeropad(input)
else:
inp_pad = input
input_unfold = torch.nn.functional.unfold(inp_pad, kernel_size).transpose(1, 2)
if input_channel < 256:
input_unfold = input_unfold.unfold(2,input_channel,input_channel)[None,:]
else:
input_unfold = input_unfold.unfold(2,256,256)[None,:]
weight_unfold = weight.view(weight.size(0), -1)
if input_channel < 256:
weight_unfold = weight_unfold.unfold(1,ichannel,ichannel)[:,None,None]
else:
weight_unfold = weight_unfold.unfold(1,256,256)[:,None,None]
mac = input_unfold * weight_unfold
mac = torch.sum(mac, dim=4)
mac = torch.sum(mac, dim=3)
out = torch.nn.functional.fold(mac, (input_size, input_size), (1, 1)).transpose(0,1)
return out
def backward(ctx, grad_output):
input, weight, bias = ctx.saved_tensors
stride, padding, dilation, groups = ctx.stride, ctx.padding, ctx.dilation, ctx.groups
grad_input = grad_weight = grad_weight = grad_input = grad_bias = grad_stride = grad_padding = grad_dilation = grad_groups = None
if ctx.needs_input_grad[1]:
grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, stride, padding, dilation, groups)
if ctx.needs_input_grad[2]:
grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, stride, padding, dilation, groups)
return grad_input, grad_weight, grad_input, grad_weight, grad_bias, grad_stride, grad_padding, grad_dilation, grad_groups