puts[idx]))
def forward(self, inputs):
    """Encode each input with its own encoder, then decode each encoding.

    Args:
        inputs: Indexable sequence of inputs; ``inputs[i]`` is passed to
            ``self.encoders[i]``. Inputs beyond ``len(self.encoders)`` are
            ignored.

    Returns:
        A list with one entry per decoder: ``self.decoders[i]`` applied to
        the i-th encoding.

    Raises:
        Whatever indexing raises (``IndexError`` for lists) if there are
        fewer inputs than encoders, or fewer encoders than decoders.
    """
    # Encoder i consumes input i.
    encoded = [enc(inputs[idx]) for idx, enc in enumerate(self.encoders)]

    # NOTE(review): an earlier revision also concatenated these encodings
    # along dim=1 and re-split them into 3-wide chunks (with a shared
    # encoder/decoder bottleneck commented out), but the split result was
    # never used -- decoders always read the per-encoder outputs directly.
    # That dead work is dropped here. If the bottleneck is reinstated, the
    # decoders should presumably consume the split latent instead; confirm
    # the intended data flow.
    return [dec(encoded[idx]) for idx, dec in enumerate(self.decoders)]