I wanted to do 1D convolution on a circle(diffraction signal circle). So i customized a CircleConv layer and made a network using it. But the speed of the network is too slow to train on my dataset. The code is as follows:
import torch
import torch.nn as nn
import torch.nn.init as init
from collections import OrderedDict
import time
class CircleConv(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, stride=1):
super(CircleConv, self).__init__()
self.kernel_size = kernel_size
self.kernel_size_number = kernel_size * 1
self.out_channels = out_channels
self.dilation = (dilation, dilation)
self.in_channels = in_channels
self.weights = nn.Parameter(torch.Tensor(self.out_channels, self.in_channels, self.kernel_size_number))
init.xavier_uniform_(self.weights)
def forward(self, x):
"""x shape: (batch, channels, l)"""
windows = self.calculateWindows(x)
result = torch.zeros(
[x.shape[0] * self.out_channels, x.shape[2]], dtype=torch.float,
device=x.device)
for channel in range(x.shape[1]):
for i_convNumber in range(self.out_channels):
xx = torch.matmul(windows[channel], self.weights[i_convNumber][channel])
xx = xx.view(-1, x.shape[2])
result[i_convNumber * xx.shape[0]: (i_convNumber + 1) * xx.shape[0]] += xx
result = result.view(self.out_channels, x.shape[0], x.shape[2]).permute(1, 0, 2)
return result
def calculateWindows(self, x):
length = x.shape[-1]
half = self.kernel_size // 2
triple = torch.cat(tensors=[x[:, :, -half:], x, x[:, :, :half]], dim=-1)
lt = []
for step in range(length):
window = triple[:, :, half + step - half:half + step + half + 1]
window = window.reshape(x.shape[0], -1)
lt.append(window)
windows = torch.stack(lt, dim=2)
windows = windows.transpose(1, 2).contiguous().view(-1, x.shape[1], self.kernel_size_number)
windows = windows.transpose(0, 1)
return windows
class CircleConvCNN(nn.Module):
def __init__(self, in_channels, output_features):
super(CircleConvCNN, self).__init__()
self.net1 = nn.Sequential(CircleConv(in_channels, 32), nn.GELU(), CircleConv(32, 32), nn.GELU(), CircleConv(32, 64), nn.GELU(), CircleConv(64, 64), nn.GELU(), CircleConv(64, 128), nn.GELU(), CircleConv(128, 128), nn.GELU())
self.global_pool = nn.AdaptiveAvgPool1d(1)
self.net2 = nn.Sequential(OrderedDict([('linear1', nn.Linear(128, 1024)), ('activation1', nn.GELU()), ('linear2', nn.Linear(1024, 512)), ('activation2', nn.GELU()), ('linear3', nn.Linear(512, 9))]))
def forward(self, x):
x = x.view(x.shape[0], x.shape[1], 1)
x = x.permute(0, 2, 1)
x = self.net1(x)
x = self.global_pool(x)
# x = self.flatten(x)
x = x.view(x.size()[0], -1)
# print(x.size())
output = self.net2(x)
return output
# Below is the test code
CNN = CircleConvCNN(1, 9)
CNN.to('cuda')
for i in range(10):
print('time tick:')
ticks = time.time()
x = torch.randn(800, 147).to('cuda')
out = CNN(x)
print('forward elapsed time:{}'.format(time.time() - ticks))
print('time tick:')
ticks = time.time()
out.mean().backward()
print('backward elapsed time:{}'.format(time.time() - ticks))
The output of above code is:
time tick:
forward elapsed time:2.740222454071045
time tick:
backward elapsed time:85.08087038993835
time tick:
forward elapsed time:2.3308119773864746
time tick:
backward elapsed time:84.98330068588257
time tick:
forward elapsed time:2.3467724323272705
time tick:
......
Why the backpropogation is so slow? Maybe some operations are slow, but i could not find them. Can someone give some suggestions?