Autoencoder with multiple decoders

Dear Team, I am trying to create an autoencoder model with one encoder and multiple decoders (the number of decoders depends on the number of classes). I am stuck on optimizing the loss. The loss is computed by adding up, for each class, the loss between that decoder's output and the original data matrix. I don't understand how to combine these losses for all the classes together. Any insights would be helpful. Thanks

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from dataset import Dataset
from autoencoder import Autoencoder
import numpy as np

# Run on the first CUDA GPU when available, otherwise on the CPU.
# (Straight ASCII quotes are required: curly quotes are a SyntaxError.)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Training hyperparameters.
# NOTE(review): train_AE() does not read num_epochs; pass it to train_AE
# explicitly if 500 epochs is what was intended.
num_epochs = 500
batch_size = 50
learning_rate = 0.01

# NOTE(review): `model` is never referenced again in this file; training is
# done on `multitaskAE`. Building it still runs Autoencoder() and moves the
# result to `device`.
model = Autoencoder()
model = model.to(device)

# Reconstruction loss: mean squared error averaged over all elements.
criterion = nn.MSELoss(reduction="mean")
class MultitaskAutoencoder(nn.Module):
    """Autoencoder with one shared encoder and several decoder heads.

    The encoder ``ec`` maps ``in_features`` -> ``hidden_features``. Each
    decoder ``dc1``..``dc5`` maps that code back to ``in_features``. Every
    layer is a single ``nn.Linear`` with no activation in between, so the
    whole model is linear.

    Only ``dc1``..``dc3`` are used by ``forward``. ``dc4`` and ``dc5`` are
    still registered, so their weights appear in ``parameters()`` and
    ``state_dict()``, but they never receive a gradient.

    Args:
        in_features: Size of the input's last dimension. Defaults to 29, the
            value previously hard-coded.
        hidden_features: Size of the shared latent code. Defaults to 50.
    """

    def __init__(self, in_features=29, hidden_features=50):
        # This must be the dunder ``__init__``. A method named ``init`` is
        # never called automatically, so the layers below were never created
        # and forward() failed with AttributeError.
        super().__init__()
        self.ec = nn.Linear(in_features, hidden_features)
        self.dc1 = nn.Linear(hidden_features, in_features)
        self.dc2 = nn.Linear(hidden_features, in_features)
        self.dc3 = nn.Linear(hidden_features, in_features)
        self.dc4 = nn.Linear(hidden_features, in_features)
        self.dc5 = nn.Linear(hidden_features, in_features)

    # "Multi_AE": stray label from the original paste (likely a diagram
    # caption). Kept as a comment because a bare name is not valid here.

    def forward(self, x, dom=0):
        """Encode ``x`` once and decode it with heads ``dc1``..``dc3``.

        Args:
            x: Input whose last dimension is ``in_features``.
            dom: Unused. Kept so existing callers that pass it still work.

        Returns:
            The three reconstructions stacked with ``torch.vstack``. For a
            batched input of shape ``(B, ..., in_features)`` the result has
            shape ``(3 * B, ..., in_features)``:
            - rows ``[0:B]`` come from ``dc1``;
            - rows ``[B:2B]`` come from ``dc2``;
            - rows ``[2B:3B]`` come from ``dc3``.
        """
        code = self.ec(x)
        reconstructions = (self.dc1(code), self.dc2(code), self.dc3(code))
        return torch.vstack(reconstructions)

# Instantiate the multi-decoder model on the selected device and inspect it.
multitaskAE = MultitaskAutoencoder().to(device)

print(multitaskAE)

# Adam over every registered parameter of the model, with light L2 decay.
optimizer = torch.optim.Adam(multitaskAE.parameters(), lr=learning_rate, weight_decay=1e-5)

# Smoke test. The input's last dimension must equal the encoder's input size
# (a hard-coded 256 does not match the default nn.Linear(29, 50) and raises a
# shape error). Autograd is disabled because no backward pass follows.
test_data = torch.randn(3, 1, multitaskAE.ec.in_features).to(device)

with torch.no_grad():
    out = multitaskAE(test_data)

print(out.shape)

def train_AE(epochs=100):
    """Train ``multitaskAE`` by summing the MSE loss of every decoder head.

    Each batch from ``Dataset()`` is an ``(inputs, targets)`` pair, where
    ``targets`` is the original data the decoders should reconstruct.
    ``multitaskAE`` stacks its decoder outputs along dim 0, so the output is
    split back into one chunk per decoder. Each chunk is compared with
    ``targets``, and the per-decoder losses are added into one scalar. A
    single ``backward()`` on that sum sends gradients to the shared encoder
    and to every decoder at once.

    Args:
        epochs: Number of passes over the dataset. Defaults to 100, the value
            previously hard-coded. NOTE(review): the module-level
            ``num_epochs`` is not used here; pass it explicitly if intended.
    """
    dataset = Dataset()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    multitaskAE.train()

    for _ in range(epochs):
        train_loss = 0.0

        for inputs, targets in dataloader:
            inputs = inputs.float().to(device)
            # Original data matrix that each decoder should reconstruct.
            targets = targets.float().to(device)

            # Clear the previous step's gradients *before* backward().
            # Calling zero_grad() after backward() (as before) threw them away.
            optimizer.zero_grad()

            output = multitaskAE(inputs)
            # The decoder outputs are vstacked, so each one occupies
            # len(inputs) rows. Split by the actual batch size rather than a
            # fixed 50, because the last batch may be smaller than batch_size.
            decoder_outputs = output.split(inputs.shape[0], dim=0)

            # NOTE(review): every decoder is compared with the same targets.
            # If decoder k should only reconstruct samples of class k, select
            # those rows (by class label) from both targets and that decoder's
            # output before computing its loss.
            final_loss = sum(criterion(dec_out, targets) for dec_out in decoder_outputs)

            final_loss.backward()
            # Apply the update. Without step() the weights never change.
            optimizer.step()

            train_loss += final_loss.item()

        print('===> Loss: ', train_loss / len(dataloader))

if __name__ == "__main__":
    # Train only when this file is run as a script, not when it is imported.
    train_AE()

Double post from here (the link to the original thread was not preserved).