I would like to use PyTorch to implement the MSFBCNN model described in this paper.
This is my current implementation:
import torch
import torch.nn as nn
class MSFBCNN(nn.Module):
    """Multi-branch temporal/spatial CNN that classifies a signal into 3 classes.

    Four parallel temporal convolutions (kernel lengths 64, 40, 26 and 16)
    are applied to the same input, concatenated along the feature axis,
    batch-normalised, mixed across sensor channels by a spatial convolution,
    average-pooled along time, flattened per sample and fed to a linear
    classifier.

    Input layout is ``(batch, 1, samples, channel_num)``: the temporal
    kernels span dim 2 and the spatial kernel spans dim 3.

    Args:
        channel_num: Number of sensor channels (size of the last input dim).
        hz: Sampling rate. It is also used as the number of filters in every
            convolution. NOTE(review): tying the filter count to the sampling
            rate looks unusual - confirm against the reference architecture.
        time: Recording length in seconds. ``hz * time`` is assumed to be the
            number of samples on dim 2 of the input - TODO confirm.

    Raises:
        ValueError: If ``hz * time`` is shorter than the pooling window.
    """

    # Average-pooling window and stride along the time axis.
    _POOL_KERNEL = 75
    _POOL_STRIDE = 15
    _NUM_CLASSES = 3

    def __init__(self, channel_num: int, hz: int, time: float) -> None:
        super().__init__()
        self.channel_num = channel_num
        self.hz = hz
        self.time = time

        n_samples = int(round(hz * time))
        if n_samples < self._POOL_KERNEL:
            raise ValueError(
                f"hz * time = {n_samples} samples is shorter than the "
                f"pooling window of {self._POOL_KERNEL} samples"
            )
        # AvgPool2d without padding yields floor((L - kernel) / stride) + 1
        # time steps; the classifier width must match this exactly.
        pooled_len = (n_samples - self._POOL_KERNEL) // self._POOL_STRIDE + 1

        self.timeconv1 = self._temporal_conv(64)
        self.timeconv2 = self._temporal_conv(40)
        self.timeconv3 = self._temporal_conv(26)
        self.timeconv4 = self._temporal_conv(16)
        self.batchnorm = nn.BatchNorm2d(self.hz * 4)
        self.spacialconv = nn.Sequential(
            # Collapses the sensor-channel axis (dim 3) to size 1.
            nn.Conv2d(
                self.hz * 4,
                self.hz,
                kernel_size=(1, self.channel_num),
                padding='valid',
            ),
            nn.BatchNorm2d(self.hz),
            # NOTE(review): verify the activation choice against the paper.
            nn.ReLU(),
            nn.AvgPool2d(
                kernel_size=(self._POOL_KERNEL, 1),
                stride=(self._POOL_STRIDE, 1),
            ),
            nn.Dropout2d(p=0.5),
        )
        self.linear = nn.Linear(self.hz * pooled_len, self._NUM_CLASSES)

    def _temporal_conv(self, kernel_len: int) -> nn.Conv2d:
        """Build one length-preserving temporal branch with ``hz`` filters."""
        return nn.Conv2d(
            1,
            self.hz,
            kernel_size=(kernel_len, 1),
            padding='same',
            bias=False,
            padding_mode='replicate',
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class scores of shape ``(batch, 3)``.

        Args:
            x: Tensor of shape ``(batch, 1, samples, channel_num)``.
        """
        branches = (
            self.timeconv1(x),
            self.timeconv2(x),
            self.timeconv3(x),
            self.timeconv4(x),
        )
        # 'same' padding keeps every branch at the input length, so the
        # branches can be stacked along the feature axis.
        features = torch.cat(branches, dim=1)
        features = self.batchnorm(features)
        features = self.spacialconv(features)
        # Flatten everything except the batch axis so samples stay separate.
        features = torch.flatten(features, start_dim=1)
        return self.linear(features)
# Smoke test: build the model for 16 sensor channels at 250 Hz and 1 second,
# push one random sample of shape (batch, 1, samples, channels) through it
# and print the shape of the result.
net = MSFBCNN(channel_num=16, hz=250, time=1)
x = torch.randn(1, 1, 250, 16)
print(net(x).shape)
Does this correctly match the structure of the model shown in the image?