My network is not sequential: it is a list (`nn.ModuleList`) of `nn.Sequential` modules that all share the same structure. In the forward function I iterate over them with a Python loop, which I suspect is inefficient. How can I run the modules in the list in parallel instead of processing each `nn.Sequential` one by one? Any help is appreciated.
class my_net(nn.Module):
    """Apply a separate Conv1d(kernel_size=1) + BatchNorm1d head per ring.

    Every point of ``x`` carries a ring id in ``ring``. The points belonging
    to ring ``i`` are gathered across the whole batch and passed through
    ``self.my_list[i]``. The per-ring results are concatenated along the
    point axis in ring order.

    Args:
        num_ring: Number of rings, i.e. number of per-ring heads.
        in_channels: Feature size ``D`` of each input point (default 64,
            matching the original hard-coded value).
        out_channels: Feature size produced by each head (default 128,
            matching the original hard-coded value).
    """

    def __init__(self, num_ring, in_channels=64, out_channels=128):
        super().__init__()
        self.num_ring = num_ring
        self.out_channels = out_channels
        # One independent head per ring; all heads share the same structure.
        self.my_list = nn.ModuleList(
            nn.Sequential(
                nn.Conv1d(in_channels, out_channels, 1),
                nn.BatchNorm1d(out_channels),
            )
            for _ in range(num_ring)
        )

    def forward(self, x, ring):
        """Run each ring's points through that ring's head.

        Args:
            x: Point features of shape [B, N, D] with D == in_channels.
            ring: Ring id of every point, shape [B, N]. Ids outside
                [0, num_ring) are ignored.

        Returns:
            Tensor of shape [1, out_channels, M], where M is the number of
            points whose ring id lies in [0, num_ring). Points are ordered by
            ring id, then by their row-major (b, n) position within the ring.

        Note:
            Boolean-mask indexing forces a host/device sync per ring, so the
            loop is serial over rings but vectorized over points. In training
            mode, BatchNorm1d raises if a ring contains exactly one point,
            because it needs more than one value per channel.
        """
        outputs = []
        for i, head in enumerate(self.my_list):
            # x[mask] flattens batch and point dims: [M_i, D] -> [1, D, M_i].
            xi = x[ring == i].unsqueeze(0).transpose(2, 1)
            # Conv1d rejects zero-length input, so skip rings with no points.
            if xi.size(-1) == 0:
                continue
            outputs.append(head(xi))  # [1, out_channels, M_i]
        if not outputs:
            # No point belongs to any ring: return an empty, well-shaped result.
            return x.new_zeros((1, self.out_channels, 0))
        return torch.cat(outputs, dim=2)
Note that each `xi` can have a different size: the number of points `M_i` assigned to ring `i` may differ from ring to ring.