My network is not sequential. It is a list of nn.Sequential modules that all have the same structure. In the forward function I use a loop over that list, which I suspect is inefficient. How can I run (and train) the list of nn.Sequential modules in parallel, rather than running each nn.Sequential one by one? I'd appreciate any help.
class my_net(nn.Module):
    """Apply a separate Conv1d(kernel_size=1) + BatchNorm1d branch to each ring.

    Every point of ``x`` is routed to the branch ``my_list[r]`` selected by the
    matching entry ``r`` of ``ring``. The per-ring outputs are concatenated
    along the point axis, grouped by ring index: all ring-0 points first, then
    ring 1, and so on.

    Args:
        num_ring: number of independent branches. Points whose ring id is not
            in ``range(num_ring)`` are ignored.
        in_channels: feature size D of each input point (default 64).
        out_channels: feature size of each output point (default 128).
    """

    def __init__(self, num_ring, in_channels=64, out_channels=128):
        super().__init__()
        self.num_ring = num_ring
        # Kept so forward() can build a correctly shaped empty result.
        self.out_channels = out_channels
        # One branch per ring: identical structure, separate weights.
        # The state_dict keys are "my_list.<i>.0/1...".
        self.my_list = nn.ModuleList(
            nn.Sequential(
                nn.Conv1d(in_channels, out_channels, 1),
                nn.BatchNorm1d(out_channels),
            )
            for _ in range(num_ring)
        )

    def forward(self, x, ring):
        """Run each ring's points through that ring's branch.

        The Python loop runs once per ring, not once per point. Each iteration
        processes every point of that ring in a single batched call.

        Args:
            x: point features. The code indexes it with a mask of ``ring``'s
                shape and expects the remaining trailing dim to be
                ``in_channels``. Presumably this is [B, N, D]; confirm with
                the caller.
            ring: integer ring id per point, same leading shape as ``x``
                (presumably [B, N]).

        Returns:
            Tensor of shape [1, out_channels, M], where M is the number of
            points whose ring id is in ``range(num_ring)``, ordered by ring.

        Note:
            In training mode, a ring containing exactly one point makes
            BatchNorm1d raise ValueError, because it needs more than one value
            per channel to compute batch statistics.
        """
        x_new = []
        for i, branch in enumerate(self.my_list):
            # Gather this ring's points: [M_i, D].
            xi = x[ring == i]
            if xi.shape[0] == 0:
                # An empty ring contributes nothing to the concatenation, and
                # BatchNorm over zero points is meaningless, so skip it.
                continue
            # [M_i, D] -> [1, M_i, D] -> [1, D, M_i], the layout Conv1d expects.
            xi = xi.unsqueeze(0).transpose(2, 1)
            # [1, D, M_i] -> [1, out_channels, M_i]
            x_new.append(branch(xi))
        if not x_new:
            # No point matched any ring; torch.cat([]) would raise here.
            return x.new_zeros(1, self.out_channels, 0)
        return torch.cat(x_new, dim=2)
Note that the xi tensors have different sizes: the number of points M_i can differ from one ring to another.