Here is a code snippet:
import torch
import torch.nn as nn
from utility import *
import torch.nn.functional as F
import numpy as np
import math
class MoE_model(nn.Module):
    """Mixture-of-modules encoder that fuses two inputs into one feature vector.

    Each input is flattened to rows of ``in_dim`` features and projected by a
    shared ``Linear -> BatchNorm1d -> ReLU`` layer into ``num_mod`` chunks of
    ``mod_dim`` features. Chunk ``i`` is processed by its own small MLP
    (``module_net[i]``). The per-module outputs are concatenated, then summed
    over groups of ``T`` consecutive rows. The two resulting vectors (one per
    input) are concatenated and passed through a final
    ``Linear -> BatchNorm1d -> ReLU`` layer.

    Args:
        device: Stored as ``self.device``; not used inside this class
            (presumably read by callers -- TODO confirm).
        in_dim: Width of each input row fed to the shared projection.
        out_dim: Width of the output feature vector.
        T: Number of consecutive rows summed into one output sample. The total
            number of rows in each input must be divisible by ``T``.
        num_mod: Number of expert modules.
        mod_dim: Input width of each expert module.
        mod_out_dim: Output width of each expert module.
    """

    def __init__(self, device='cuda', in_dim=128, out_dim=256, T=9,
                 num_mod=16, mod_dim=32, mod_out_dim=8):
        super().__init__()
        self.num_mod = num_mod
        self.mod_dim = mod_dim
        self.mod_out_dim = mod_out_dim
        self.device = device
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.T = T

        # Shared projection: one in_dim row -> num_mod chunks of mod_dim features.
        self.mod_layer_1 = nn.Linear(self.in_dim, self.mod_dim * self.num_mod)
        self.mod_layer_1_bn = nn.BatchNorm1d(self.mod_dim * self.num_mod)

        # One independent expert MLP per chunk.
        self.module_net = nn.ModuleList()
        for _ in range(self.num_mod):
            self.module_net.append(nn.Sequential(
                nn.Linear(self.mod_dim, 48),
                nn.BatchNorm1d(48),
                nn.ReLU(True),
                nn.Linear(48, self.mod_out_dim),
                nn.BatchNorm1d(self.mod_out_dim),
                nn.ReLU(True),
            ))

        # Fusion head over the concatenated encodings of both inputs (hence *2).
        self.rel_local_fc_1 = nn.Linear(self.num_mod * self.mod_out_dim * 2, self.out_dim)
        # Must match rel_local_fc_1's output width. It was previously hard-coded
        # to 256, which broke every out_dim other than the default.
        self.rel_local_fc_1_bn = nn.BatchNorm1d(self.out_dim)

    def _encode(self, x):
        """Encode one input with the shared projection and the expert modules.

        Returns a tensor of shape ``(rows // T, num_mod * mod_out_dim)``, where
        ``rows`` is the number of ``in_dim``-wide rows in ``x``.
        """
        fm = F.relu(self.mod_layer_1_bn(self.mod_layer_1(x.reshape(-1, self.in_dim))))
        fm = fm.view(-1, self.num_mod, self.mod_dim)
        # Index the module axis directly. The old split + bare squeeze() also
        # removed the row axis when there was a single row, which fed 1-D input
        # to BatchNorm1d.
        fm = torch.cat([mod(fm[:, i, :]) for i, mod in enumerate(self.module_net)], -1)
        # Sum each group of T consecutive rows into one sample.
        return torch.sum(fm.view(-1, self.T, self.num_mod * self.mod_out_dim), 1)

    def forward(self, fl_02, fl_12):
        """Fuse two inputs into a ``(rows // T, out_dim)`` feature tensor.

        Both inputs share every layer (projection, experts and their BatchNorm
        statistics). Each input is reshaped to rows of ``in_dim``; the row count
        must be divisible by ``T``.
        """
        fm_cat = torch.cat([self._encode(fl_02), self._encode(fl_12)], 1)
        return F.relu(self.rel_local_fc_1_bn(self.rel_local_fc_1(fm_cat)))