Hi there all,

I have written a convolution function in pure PyTorch in the hope of using it for N-dimensional convolutions.

However, the weights in my function do not seem to be updating.

Can anyone help me?

```python
# coding=utf-8
import math
import numpy as np
import warnings
import torch
from torch.nn.parameter import Parameter
import torch.nn.init as init
from torch.nn.modules.module import Module
from torch._jit_internal import List, Optional
import operator
from functools import partial

def flip(x, dim):
    """Return a copy of tensor `x` reversed along dimension `dim`.

    Negative `dim` values are supported, counting from the last axis
    (same convention as torch.flip / numpy.flip).
    """
    if dim < 0:
        dim += x.dim()
    # Build a full-slice index for every axis, then replace the target
    # axis with a descending index to reverse it.
    index = [slice(None)] * x.dim()
    index[dim] = torch.arange(x.size(dim) - 1, -1, -1, dtype=torch.long)
    return x[tuple(index)]

def GetLength(array_shape, p):
    """Return the flattened ("helix") length spanned by a tensor of shape
    `array_shape` when laid out on a helix whose padded period is `p`.

    NOTE(review): the `def` line of this function was lost when the code
    was pasted; this signature is reconstructed from the two call sites in
    HelixConvolve — GetLength(input_shape[2:], input_shape[2:] + weight_shape[2:] - 1)
    — where the first argument is the spatial shape and the second the
    padded (full-convolution) shape. Please confirm against the original.

    Args:
        array_shape: 1-D int tensor, spatial shape of the array.
        p: 1-D tensor of the same length, padded shape per axis.

    Returns:
        0-dim float tensor: helix index of the array's last element.
    """
    a = array_shape
    temp = 1
    steps = torch.zeros(len(a))

    # Cumulative products of the reversed padded shape act as per-axis
    # strides; the fastest-varying (last) axis gets stride 1.
    for i, entry in enumerate(flip(p, 0)):
        if i == len(p) - 1:
            steps[i] = 1
        else:
            temp = entry * temp
            steps[i] = temp

    # Rotate so the unit stride lands on the last axis, then restore the
    # original axis order.
    steps = torch.roll(steps, 1)
    steps = flip(steps, 0)

    # Dot the strides with (shape - 1) on every axis except the last
    # (the last axis keeps its full extent) to get the flattened length.
    ones = torch.ones(len(a))
    ones[-1] = 0
    temp_array = a - ones
    out = torch.matmul(steps, temp_array)
    length = torch.sum(out)  # matmul of two 1-D tensors is already scalar
    return length

def ComplexMult(a, b):
    """Batched pointwise complex product with channel contraction.

    `a` has shape (b, c, t, 2) and `b` has shape (d, c, t, 2), where the
    trailing axis holds (real, imag). The einsum contracts the shared `c`
    axis, producing shape (b, d, t, 2).

    Note the sign convention visible in the two terms: the result is
    a * conj(b) per element — (a0 + i*a1)(b0 - i*b1) — i.e. correlation
    rather than plain multiplication in the frequency domain.

    NOTE(review): the `return torch.stack([` line was lost in the paste
    (the orphaned `],` / `dim=-1)` lines prove it); it is restored here.
    """
    op = partial(torch.einsum, "bct,dct->bdt")
    return torch.stack([
        op(a[..., 0], b[..., 0]) + op(a[..., 1], b[..., 1]),
        op(a[..., 1], b[..., 0]) - op(a[..., 0], b[..., 1])
    ],
    dim=-1)

# NOTE(review): this function is incomplete as pasted (indentation was also
# lost). The names `end_i`, `end_w`, `end_c` (used at L81/L82/L95-equivalents
# below) and the forward-FFT / ComplexMult step that should define `conv`
# before the irfft call are all missing — those lines were dropped in the
# paste. The comments below describe only what the surviving code shows.
def HelixConvolve(input, weight, bias, ndim):
# input/weight are assumed to be (batch, channels, *spatial) / (out_ch, in_ch, *spatial)
# tensors — the [:2] / [2:] splits below rely on exactly 2 leading axes.
input_shape = torch.Tensor(list(input.shape)).int()
weight_shape = torch.Tensor(list(weight.shape)).int()

# Shape of the input zero-padded to the full-convolution size per spatial axis.
return_shape_input = torch.zeros(ndim+2)
return_shape_input[:2] = input_shape[:2]
return_shape_input[2:] = input_shape[2:] + weight_shape[2:]-1
return_shape_input = return_shape_input.int()

# Shape of the weight zero-padded to the same full-convolution size.
return_shape_weight = torch.zeros(ndim+2)
return_shape_weight[:2] = weight_shape[:2]
return_shape_weight[2:] = input_shape[2:] + weight_shape[2:]-1
return_shape_weight = return_shape_weight.int()

# Output shape of the "full" convolution: (batch, out_channels, N + K - 1 per axis).
conv_shape = torch.zeros(ndim+2)
conv_shape[0] = input_shape[0]
conv_shape[1] = weight_shape[0]
conv_shape[2:] = input_shape[2:] + weight_shape[2:] - 1
conv_shape = conv_shape.int()

# Output shape of the "valid" convolution: (batch, out_channels, N - K + 1 per axis).
conv_valid_shape = torch.zeros(ndim+2)
conv_valid_shape[0] = input_shape[0]
conv_valid_shape[1] = weight_shape[0]
conv_valid_shape[2:] = input_shape[2:] - weight_shape[2:] + 1
conv_valid_shape = conv_valid_shape.int()

# Slices used to paste input/weight into their zero-padded buffers.
# NOTE(review): `end_i` and `end_w` are never defined in this paste —
# presumably tuples of return_shape_input / return_shape_weight entries.
start = tuple(torch.zeros(ndim+2).int())
slices_i = tuple(map(slice, start, end_i))
slices_w = tuple(map(slice, start, end_w))

# Helix lengths of the padded input and weight; the 1-D FFT convolution
# result is valid for the first len_i + len_w - 1 samples.
len_i = GetLength(input_shape[2:], input_shape[2:] + weight_shape[2:]-1)
len_w = GetLength(weight_shape[2:], input_shape[2:] + weight_shape[2:]-1)
len_full = len_i.int()+len_w.int()-1

# NOTE(review): `conv` is used before being defined — the rfft of input and
# weight and the ComplexMult combining them are missing from the paste.
# torch.irfft is the pre-1.8 API (removed since); modern code would use
# torch.fft.irfft — confirm the target PyTorch version.
conv = torch.irfft(conv, 1, onesided=False)
conv = conv[:,:,0:len_full.item()]
#         if(bias is not None):
#             conv+=bias

# Unroll the 1-D helix result back into the N-D full-convolution shape.
conv = conv.reshape(list(conv_shape))

# Crop the full convolution down to the valid region.
# NOTE(review): `end_c` is also undefined here — presumably built from
# conv_valid_shape.
slices_c = tuple(map(slice, start, end_c))
conv = conv[slices_c]

return conv.float()

class HelixConv(Module):
    """N-dimensional convolution module built on helix (FFT) convolution.

    The weight has shape (out_channels, in_channels, *kernel_size) and the
    optional bias has shape (out_channels, 1). `stride`, `dilation` and
    `groups` are stored but not yet consumed by HelixConvolve.

    NOTE(review): the original __init__ signature was truncated in the
    paste after `stride=1,`; the remaining parameters are reconstructed
    from the attributes the body assigns (`groups`, `dilation`, `bias`).
    """

    def __init__(self, in_channels, out_channels, kernel_size, ndim=2, stride=1,
                 dilation=1, groups=1, bias=True):
        super(HelixConv, self).__init__()
        self.kernel_size = kernel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ndim = ndim
        self.stride = stride
        self.groups = groups
        self.dilation = dilation

        # Build (out_channels, in_channels, *kernel_size) as an int tensor;
        # its entries support __index__ so torch.randn accepts list(...) below.
        weight_shape = torch.zeros(self.ndim + 2)
        weight_shape[0] = out_channels
        weight_shape[1] = in_channels
        weight_shape[2:] = torch.Tensor(kernel_size)
        weight_shape = weight_shape.int()

        # BUG FIX (the issue in this thread): the tensors must be wrapped in
        # nn.Parameter so they are registered with the module, returned by
        # .parameters(), and therefore updated by the optimizer. Plain
        # tensors are invisible to autograd's optimizer step.
        if bias:
            self.bias = Parameter(torch.zeros(out_channels, 1))
        else:
            self.bias = None

        self.weight = Parameter(torch.randn(list(weight_shape)))

    def forward(self, input):
        # Delegate to the functional helix convolution.
        return HelixConvolve(input,
                             self.weight,
                             self.bias,
                             self.ndim)
```

Hi,

I think you forgot to wrap your weight/bias tensors in `nn.Parameter()` when you define them.
It should be `self.weight = nn.Parameter(torch.randn(list(weight_shape)))` (and likewise `self.bias = nn.Parameter(torch.zeros(out_channels, 1))`), so the tensors are registered with the module, show up in `.parameters()`, and get updated by the optimizer.