Hi there all,
I have written a convolution function using pure Pytorch in the hopes of using it for N-Dimensional convolutions.
However, the weights in my function do not seem to be updating.
Can anyone help me?
# coding=utf-8
import math
import numpy as np
import warnings
import torch
from torch.nn.parameter import Parameter
import torch.nn.init as init
from torch.nn.modules.module import Module
from torch.autograd import Variable
from torch._jit_internal import List, Optional
import operator
from functools import partial
def flip(x, dim):
dim = x.dim() + dim if dim < 0 else dim
return x[tuple(slice(None, None) if i != dim
else torch.arange(x.size(i)-1, -1, -1).long()
for i in range(x.dim()))]
def GetLength(array_shape, padded_shape):
a = array_shape
p = padded_shape
temp=1
steps=torch.zeros(len(a))
for i, entry in enumerate(flip(p, 0)):
if(i==len(p)-1):
steps[i]=1
else:
temp=entry*temp
steps[i]=temp
steps=torch.roll(steps, 1)
steps=flip(steps, 0)
ones=torch.ones(len(a))
ones[-1]=0
temp_array = a-ones
out=torch.matmul(steps,temp_array)
length = torch.sum(out)
return length
def ComplexMult(a, b):
op = partial(torch.einsum, "bct,dct->bdt")
return torch.stack([
op(a[..., 0], b[..., 0]) + op(a[..., 1], b[..., 1]),
op(a[..., 1], b[..., 0]) - op(a[..., 0], b[..., 1])
],
dim=-1)
def HelixConvolve(input, weight, bias, ndim):
input_shape = torch.Tensor(list(input.shape)).int()
weight_shape = torch.Tensor(list(weight.shape)).int()
return_shape_input = torch.zeros(ndim+2)
return_shape_input[:2] = input_shape[:2]
return_shape_input[2:] = input_shape[2:] + weight_shape[2:]-1
return_shape_input = return_shape_input.int()
return_shape_weight = torch.zeros(ndim+2)
return_shape_weight[:2] = weight_shape[:2]
return_shape_weight[2:] = input_shape[2:] + weight_shape[2:]-1
return_shape_weight = return_shape_weight.int()
conv_shape = torch.zeros(ndim+2)
conv_shape[0] = input_shape[0]
conv_shape[1] = weight_shape[0]
conv_shape[2:] = input_shape[2:] + weight_shape[2:] - 1
conv_shape = conv_shape.int()
conv_valid_shape = torch.zeros(ndim+2)
conv_valid_shape[0] = input_shape[0]
conv_valid_shape[1] = weight_shape[0]
conv_valid_shape[2:] = input_shape[2:] - weight_shape[2:] + 1
conv_valid_shape = conv_valid_shape.int()
input_pad = Variable(torch.zeros(list(return_shape_input))).cuda()
weight_pad = Variable(torch.zeros(list(return_shape_weight))).cuda()
start = tuple(torch.zeros(ndim+2).int())
end_i = tuple(map(operator.add, start, input_shape))
end_w = tuple(map(operator.add, start, weight_shape))
slices_i = tuple(map(slice, start, end_i))
slices_w = tuple(map(slice, start, end_w))
input_pad[slices_i] = input
weight_pad[slices_w] = weight
len_i = GetLength(input_shape[2:], input_shape[2:] + weight_shape[2:]-1)
len_w = GetLength(weight_shape[2:], input_shape[2:] + weight_shape[2:]-1)
len_full = len_i.int()+len_w.int()-1
input_pad = input_pad.flatten(start_dim=2)
weight_pad = weight_pad.flatten(start_dim=2)
input_pad = torch.rfft(input_pad, 1, onesided=False)
weight_pad = torch.rfft(weight_pad, 1, onesided=False)
conv = ComplexMult(input_pad, weight_pad)
conv = torch.irfft(conv, 1, onesided=False)
conv = conv[:,:,0:len_full.item()]
# if(bias is not None):
# conv+=bias
conv = conv.reshape(list(conv_shape))
end_c = tuple(map(operator.add, start, conv_valid_shape))
slices_c = tuple(map(slice, start, end_c))
conv = conv[slices_c]
return conv.float()
class HelixConv(Module):
def __init__(self, in_channels, out_channels, kernel_size, ndim=2, stride=1,
padding=0, dilation=1, groups=1,
bias=True, padding_mode='zeros'):
super(HelixConv, self).__init__()
self.kernel_size = kernel_size
self.in_channels = in_channels
self.out_channels = out_channels
self.ndim = ndim
self.stride=stride
self.padding = padding
self.padding_mode = padding_mode
self.groups=groups
self.dilation = dilation
weight_shape = torch.zeros(self.ndim+2)
weight_shape[0] = out_channels
weight_shape[1] = in_channels
weight_shape[2:] = torch.Tensor(kernel_size)
weight_shape = weight_shape.int()
if bias:
self.bias = torch.Tensor(torch.zeros(out_channels, 1))
else:
self.bias = None
self.weight = torch.Tensor(torch.randn(list(weight_shape)))
def forward(self, input):
return HelixConvolve(input,
self.weight,
self.bias,
self.ndim
)