Implementing a Mixture of Experts layer

I am trying to implement a mixture-of-experts layer, similar to the one described in:

Basically, this layer has a number of expert sub-layers F_i(x_i), each processing a projected version of the input. There is also a gating layer G_i(x_i), which is basically an attention mechanism over all expert sub-layers:

sum_i( G_i(x_i) * F_i(x_i) )
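In code, the combination step would be something like this (just a sketch to show what I mean; the names are made up and the gate here is a plain softmax):

import torch
import torch.nn.functional as F

# gate_logits: (batch, num_experts), expert_outs: list of (batch, expert_out_dim)
def combine(gate_logits, expert_outs):
    gates = F.softmax(gate_logits, dim=-1)              # G_i(x), sums to 1 over experts
    stacked = torch.stack(expert_outs, dim=1)           # (batch, num_experts, expert_out_dim)
    return (gates.unsqueeze(-1) * stacked).sum(dim=1)   # sum_i G_i(x_i) * F_i(x_i)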

My naive approach is to build a ModuleList for the expert sub-layers:

sublayer_list = nn.ModuleList()
for i in range(num_of_layer):
    sublayer_list.append(self.make_layer())

Then, when applying it, I use another for loop:

out_list = []
for i, l in enumerate(sublayer_list):
    out_list.append(l(inputs[i]))  # inputs[i]: the projected slice of the input for expert i

However, adding this Mixture-of-Experts layer slows training by almost 7x compared to the same model with the MoE layer swapped for a similar-sized MLP. I am wondering if there are more efficient ways to implement this in PyTorch? Many thanks!
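One direction I was considering is to pack all the expert weights into a single tensor, so that the per-expert Python loop becomes one batched matmul. A rough sketch of the idea (this is a dense version where every sample still goes through every expert, and the initialization is arbitrary):

import torch
import torch.nn as nn

class BatchedExperts(nn.Module):
    # All expert weights live in one (num_experts, in_dim, out_dim) tensor,
    # so the loop over experts becomes a single einsum/bmm call.
    def __init__(self, num_experts, in_dim, out_dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(num_experts, in_dim, out_dim) * in_dim ** -0.5)
        self.bias = nn.Parameter(torch.zeros(num_experts, out_dim))

    def forward(self, x):
        # x: (batch, num_experts, in_dim) -- one projected slice per expert
        return torch.einsum('bei,eio->beo', x, self.weight) + self.bias

I have not benchmarked this myself yet, so I am not sure how much it actually helps in practice.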

Hello, we’re developing a library for sparse training in PyTorch. Please provide us with your complete PyTorch code, and we’ll optimize it and include it in our library.

I re-implemented the Sparsely-Gated Mixture-of-Experts layer based on the TensorFlow code here. You can find my implementation here: https://github.com/davidmrau/mixture-of-experts

Here is a code snippet:

import torch
import torch.nn as nn
from utility import *  # project-specific helpers from the linked repo
import torch.nn.functional as F
import numpy as np
import math

class MoE_model(nn.Module):
    def __init__(self, device='cuda', in_dim=128, out_dim=256, T=9,
                 num_mod=16, mod_dim=32, mod_out_dim=8):
        super(MoE_model, self).__init__()

        self.num_mod = num_mod
        self.mod_dim = mod_dim
        self.mod_out_dim = mod_out_dim
        self.device = device
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.T = T

        # shared projection: maps the input to one mod_dim slice per expert
        self.mod_layer_1 = nn.Linear(self.in_dim, self.mod_dim * self.num_mod)
        self.mod_layer_1_bn = nn.BatchNorm1d(self.mod_dim * self.num_mod)

        # one small MLP per expert
        self.module_net = nn.ModuleList()
        for i in range(self.num_mod):
            mod = nn.Sequential(
                nn.Linear(self.mod_dim, 48),
                nn.BatchNorm1d(48),
                nn.ReLU(True),

                nn.Linear(48, self.mod_out_dim),
                nn.BatchNorm1d(self.mod_out_dim),
                nn.ReLU(True)
            )
            self.module_net.append(mod)

        # final projection over the concatenated expert outputs of both input streams
        self.rel_local_fc_1 = nn.Linear(self.num_mod * self.mod_out_dim * 2, self.out_dim)
        self.rel_local_fc_1_bn = nn.BatchNorm1d(self.out_dim)

    def forward(self, fl_02, fl_12):
        fm_02 = F.relu(self.mod_layer_1_bn(self.mod_layer_1(fl_02.view(-1, self.in_dim))))
        fm_12 = F.relu(self.mod_layer_1_bn(self.mod_layer_1(fl_12.view(-1, self.in_dim))))

        # split the shared projection into one (batch, mod_dim) slice per expert
        fm_02_split = torch.split(fm_02.view(-1, self.num_mod, self.mod_dim), 1, 1)
        fm_12_split = torch.split(fm_12.view(-1, self.num_mod, self.mod_dim), 1, 1)

        fm_02_list = []
        fm_12_list = []
        for i, l in enumerate(self.module_net):
            fm_02_list.append(l(fm_02_split[i].squeeze(1)))
            fm_12_list.append(l(fm_12_split[i].squeeze(1)))
        fm_02 = torch.cat(fm_02_list, -1)
        fm_12 = torch.cat(fm_12_list, -1)

        # sum over the T time steps, then combine both streams
        fm_02_sum = torch.sum(fm_02.view(-1, self.T, self.num_mod * self.mod_out_dim), 1)
        fm_12_sum = torch.sum(fm_12.view(-1, self.T, self.num_mod * self.mod_out_dim), 1)
        fm_cat = torch.cat([fm_02_sum, fm_12_sum], 1)
        fl = F.relu(self.rel_local_fc_1_bn(self.rel_local_fc_1(fm_cat)))
        return fl
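For reference, a quick usage sketch with the default sizes (the batch size and CPU device here are just placeholders, and it assumes the snippet above runs as posted):

model = MoE_model(device='cpu')      # defaults: in_dim=128, out_dim=256, T=9, num_mod=16
fl_02 = torch.randn(4, 9, 128)       # (batch, T, in_dim); flattened to (batch*T, in_dim) inside
fl_12 = torch.randn(4, 9, 128)
out = model(fl_02, fl_12)            # (batch, out_dim) == (4, 256)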

Thanks, I will have a look.

Have you tested the speed of running this on GPUs?

It looks like for this model, all the weights are used in each forward pass? Or is my interpretation off?

Yes, for this simple model there is no gating mechanism yet; the output from each expert layer is just concatenated at the end. My concern is that this model is already very slow to train, and adding the gating mechanism will make it even slower. The training speed is far slower than an MLP with a similar number of parameters, which is kind of weird since the number of FLOPs should be roughly the same.
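In case it helps, here is roughly how I would check where the time goes. My suspicion is that 16 tiny matmuls per step cost far more in kernel-launch/Python overhead than their FLOPs suggest (toy sizes matching the defaults above; it assumes a CUDA device, and the synchronize calls are needed for meaningful GPU timings):

import time
import torch
import torch.nn as nn

x = torch.randn(512, 16, 32, device='cuda')
experts = nn.ModuleList([nn.Linear(32, 48) for _ in range(16)]).cuda()
fused = nn.Linear(16 * 32, 48).cuda()    # roughly the same parameter count / FLOPs as the 16 experts combined

def bench(fn, n=200):
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(n):
        fn()
    torch.cuda.synchronize()
    return (time.time() - t0) / n

t_loop = bench(lambda: [l(x[:, i]) for i, l in enumerate(experts)])   # 16 small kernel launches per call
t_fused = bench(lambda: fused(x.flatten(1)))                          # one large kernel launch per call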

I have not yet implemented the distribution over multiple GPUs. Even on a single GPU, though, the speed-up will be significant because of the architecture of the sparsely-gated MoE: instead of passing each sample through all experts, the gating mechanism makes sure that each sample is only passed through k experts.
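For what it's worth, the routing step looks conceptually like this, heavily simplified (this is not the exact code from the repo; it is just the top-k idea, without the noise term or the load-balancing loss):

import torch
import torch.nn.functional as F

def top_k_gating(x, w_gate, k=2):
    # x: (batch, in_dim), w_gate: (in_dim, num_experts)
    logits = x @ w_gate
    top_vals, top_idx = logits.topk(k, dim=-1)               # keep only k experts per sample
    gates = torch.zeros_like(logits)
    gates.scatter_(-1, top_idx, F.softmax(top_vals, dim=-1))
    return gates                                             # (batch, num_experts), k non-zeros per row

Samples are then dispatched only to the experts with non-zero gate values, which is where the compute savings come from.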