PyTorch Autograd Not Working

Hi there all,

I have written a convolution function in pure PyTorch, in the hope of using it for N-dimensional convolutions.

However, the weights in my function do not seem to be updating during training.

Can anyone help me?

# coding=utf-8
import operator
from functools import partial

import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

def flip(x, dim):
    # Reverse a tensor along `dim` (equivalent to torch.flip(x, [dim])).
    dim = x.dim() + dim if dim < 0 else dim
    return x[tuple(slice(None, None) if i != dim
                   else torch.arange(x.size(i) - 1, -1, -1).long()
                   for i in range(x.dim()))]
    
def GetLength(array_shape, padded_shape):
    # Length of the 1-D (helix) signal that covers an array of shape
    # `array_shape` embedded in the corner of a padded array of shape
    # `padded_shape`: the dot product of the padded array's C-order strides
    # with (array_shape - 1), plus one.
    a = array_shape
    p = padded_shape
    temp = 1
    steps = torch.zeros(len(a))

    # Cumulative products of the padded shape give the stride of each axis.
    for i, entry in enumerate(flip(p, 0)):
        if i == len(p) - 1:
            steps[i] = 1
        else:
            temp = entry * temp
            steps[i] = temp

    steps = torch.roll(steps, 1)
    steps = flip(steps, 0)
    ones = torch.ones(len(a))
    ones[-1] = 0
    temp_array = a - ones
    out = torch.matmul(steps, temp_array)
    length = torch.sum(out)
    return length
        
 
def ComplexMult(a, b):
    # Pointwise product of rfft outputs, with (real, imag) stacked in the last
    # dimension. This multiplies a by the conjugate of b, and the einsum
    # ("bct,dct->bdt") also sums over the input-channel dimension, so it fuses
    # the frequency-domain multiply with the channel reduction.
    op = partial(torch.einsum, "bct,dct->bdt")
    return torch.stack([
        op(a[..., 0], b[..., 0]) + op(a[..., 1], b[..., 1]),
        op(a[..., 1], b[..., 0]) - op(a[..., 0], b[..., 1])
    ], dim=-1)

    
def HelixConvolve(input, weight, bias, ndim):
    input_shape = torch.Tensor(list(input.shape)).int()
    weight_shape = torch.Tensor(list(weight.shape)).int()

    # Both operands are zero-padded to the full convolution size,
    # (batch, channels, *(input_spatial + kernel_spatial - 1)).
    return_shape_input = torch.zeros(ndim + 2)
    return_shape_input[:2] = input_shape[:2]
    return_shape_input[2:] = input_shape[2:] + weight_shape[2:] - 1
    return_shape_input = return_shape_input.int()

    return_shape_weight = torch.zeros(ndim + 2)
    return_shape_weight[:2] = weight_shape[:2]
    return_shape_weight[2:] = input_shape[2:] + weight_shape[2:] - 1
    return_shape_weight = return_shape_weight.int()

    # Shape of the full convolution output and of its 'valid' crop.
    conv_shape = torch.zeros(ndim + 2)
    conv_shape[0] = input_shape[0]
    conv_shape[1] = weight_shape[0]
    conv_shape[2:] = input_shape[2:] + weight_shape[2:] - 1
    conv_shape = conv_shape.int()

    conv_valid_shape = torch.zeros(ndim + 2)
    conv_valid_shape[0] = input_shape[0]
    conv_valid_shape[1] = weight_shape[0]
    conv_valid_shape[2:] = input_shape[2:] - weight_shape[2:] + 1
    conv_valid_shape = conv_valid_shape.int()

    # NB: the padded buffers are hard-coded to CUDA here.
    input_pad = torch.zeros(list(return_shape_input)).cuda()
    weight_pad = torch.zeros(list(return_shape_weight)).cuda()

    # Copy input and weight into the corner of the padded buffers.
    start = tuple(torch.zeros(ndim + 2).int())
    end_i = tuple(map(operator.add, start, input_shape))
    end_w = tuple(map(operator.add, start, weight_shape))
    slices_i = tuple(map(slice, start, end_i))
    slices_w = tuple(map(slice, start, end_w))

    input_pad[slices_i] = input
    weight_pad[slices_w] = weight

    # Lengths of the flattened (helix) signals and of their full 1-D convolution.
    len_i = GetLength(input_shape[2:], input_shape[2:] + weight_shape[2:] - 1)
    len_w = GetLength(weight_shape[2:], input_shape[2:] + weight_shape[2:] - 1)
    len_full = len_i.int() + len_w.int() - 1

    # Flatten all spatial dimensions onto one axis: the helix trick turns the
    # padded N-D convolution into a single 1-D convolution.
    input_pad = input_pad.flatten(start_dim=2)
    weight_pad = weight_pad.flatten(start_dim=2)

    # 1-D convolution via the FFT (torch.rfft/torch.irfft API, PyTorch <= 1.7).
    input_pad = torch.rfft(input_pad, 1, onesided=False)
    weight_pad = torch.rfft(weight_pad, 1, onesided=False)
    conv = ComplexMult(input_pad, weight_pad)
    conv = torch.irfft(conv, 1, onesided=False)
    conv = conv[:, :, 0:len_full.item()]
    # if bias is not None:
    #     conv += bias

    conv = conv.reshape(list(conv_shape))

    # Keep only the 'valid' region of the full convolution.
    end_c = tuple(map(operator.add, start, conv_valid_shape))
    slices_c = tuple(map(slice, start, end_c))
    conv = conv[slices_c]

    return conv.float()

class HelixConv(Module):

    def __init__(self, in_channels, out_channels, kernel_size, ndim=2, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros'):
        super(HelixConv, self).__init__()
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ndim = ndim
        self.stride = stride
        self.padding = padding
        self.padding_mode = padding_mode
        self.groups = groups
        self.dilation = dilation

        weight_shape = torch.zeros(self.ndim + 2)
        weight_shape[0] = out_channels
        weight_shape[1] = in_channels
        weight_shape[2:] = torch.Tensor(kernel_size)
        weight_shape = weight_shape.int()

        if bias:
            self.bias = torch.zeros(out_channels, 1)
        else:
            self.bias = None

        self.weight = torch.randn(list(weight_shape))
        
    
    def forward(self, input):
        return HelixConvolve(input, self.weight, self.bias, self.ndim)
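
For reference, here is (roughly) the simplified check I am using; the shapes are just placeholders. The gradient never shows up and parameters() comes back empty:

# Simplified check (placeholder shapes). The weight is a plain tensor, so I
# have to move it onto the GPU by hand for this test.
conv = HelixConv(in_channels=3, out_channels=8, kernel_size=(3, 3), ndim=2)
conv.weight = conv.weight.cuda()

x = torch.randn(4, 3, 16, 16, device='cuda', requires_grad=True)
out = conv(x)
out.sum().backward()

print(conv.weight.grad)         # prints None
print(list(conv.parameters()))  # prints []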

Hi,

I think you forgot to wrap your weight/bias in nn.Parameter() when you define them. A plain tensor assigned to a Module is not registered with it: it never shows up in model.parameters(), so the optimizer has nothing to update, and it is created with requires_grad=False, so no gradient is ever computed for it either.
It should be self.weight = nn.Parameter(torch.randn(list(weight_shape))), and the same for the bias.
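
Since you already import Parameter from torch.nn.parameter at the top of your file, a minimal sketch of the fix in __init__ would be:

# Sketch of the fix in HelixConv.__init__, using the Parameter you already
# import. Wrapping the tensors registers them with the Module, so
# parameters() returns them, the optimizer updates them, and .cuda()/.to()
# move them together with the rest of the model.
if bias:
    self.bias = Parameter(torch.zeros(out_channels, 1))
else:
    self.register_parameter('bias', None)

self.weight = Parameter(torch.randn(list(weight_shape)))

register_parameter('bias', None) keeps the attribute defined while telling the Module there is no bias parameter. After this change, list(model.parameters()) is no longer empty and weight.grad is filled in after backward().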