I’ve been trying to build a custom channel-wise convolution operation with multiplication and sum operation so that I can alter the bit-precision during the convolution process.
(I’ve checked several forum articles and all of them told me I should build it with unfold and fold functions.)
But when I replace the torch.nn.conv2d function with my code
There is no error but accuracy seems to degrade a lot
I’m completely new to python and pytorch, and only know basics about neural network so I can’t find what mistakes I made.
Can anyone please help? I would really appreciate any help.
Also, if there is any tips to save gpu usage I would also appreciate it very much, since my code seems to consume tons of memory so I had to cut the batch size and it’s taking forever to see the results.
class Conv2DFunctionCustom(Function): @staticmethod def forward(ctx, input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): ctx.save_for_backward(input, weight, bias) ctx.stride, ctx.padding, ctx.dilation, ctx.groups = stride, padding, dilation, groups zeropad = nn.ZeroPad2d(ctx.padding) batch_size = len(input) input_channel = len(input) input_size = len(input) ochannel = len(weight) ichannel = len(weight) kernel_size = len(weight) if padding: inp_pad = zeropad(input) else: inp_pad = input input_unfold = torch.nn.functional.unfold(inp_pad, kernel_size).transpose(1, 2) if input_channel < 256: input_unfold = input_unfold.unfold(2,input_channel,input_channel)[None,:] else: input_unfold = input_unfold.unfold(2,256,256)[None,:] weight_unfold = weight.view(weight.size(0), -1) if input_channel < 256: weight_unfold = weight_unfold.unfold(1,ichannel,ichannel)[:,None,None] else: weight_unfold = weight_unfold.unfold(1,256,256)[:,None,None] mac = input_unfold * weight_unfold mac = torch.sum(mac, dim=4) mac = torch.sum(mac, dim=3) out = torch.nn.functional.fold(mac, (input_size, input_size), (1, 1)).transpose(0,1) return out def backward(ctx, grad_output): input, weight, bias = ctx.saved_tensors stride, padding, dilation, groups = ctx.stride, ctx.padding, ctx.dilation, ctx.groups grad_input = grad_weight = grad_weight = grad_input = grad_bias = grad_stride = grad_padding = grad_dilation = grad_groups = None if ctx.needs_input_grad: grad_input = torch.nn.grad.conv2d_input(input.shape, weight, grad_output, stride, padding, dilation, groups) if ctx.needs_input_grad: grad_weight = torch.nn.grad.conv2d_weight(input, weight.shape, grad_output, stride, padding, dilation, groups) return grad_input, grad_weight, grad_input, grad_weight, grad_bias, grad_stride, grad_padding, grad_dilation, grad_groups