Multi-task DNN: compute outputs only for the tasks labeled in the current batch

Hi all,
I’m working on a multi-task learning model with, say, 1000 tasks. I used nn.ModuleList() to hold the task heads, as in the model below. With a batch size of 32, the output is a list of 1000 tensors, each holding 32 predicted values. The problem is that the label matrix is very sparse (>99% of entries missing): maybe only 10 of those 1000 tasks actually have labels in a batch and can contribute to the loss, so I have to mask the outputs of the other 990. Computing predictions for 990 tasks that never contribute to the loss is a waste of compute. How can I restructure the model so that it only computes the dozen or so tasks that have experimental labels and skips all the others? Note that the set of labeled tasks changes from one batch to the next.
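A first step is to find which task columns actually carry labels in the current batch. The sketch below assumes, as the training loop in this thread does, that missing labels are encoded as -100.0 in a dense `[batch, n_tasks]` label tensor; `labeled_task_indices` is a hypothetical helper name, not part of PyTorch.

```python
import torch

MISSING = -100.0  # sentinel for "no label", as used by the mask in the training loop


def labeled_task_indices(labels: torch.Tensor, missing: float = MISSING) -> torch.Tensor:
    """Return the indices of task columns with at least one label in this batch."""
    has_label = labels != missing              # [batch, n_tasks] boolean mask
    return has_label.any(dim=0).nonzero(as_tuple=True)[0]


# Toy batch: 4 samples, 6 tasks; only tasks 1 and 4 are labeled.
labels = torch.full((4, 6), MISSING)
labels[0, 1] = 0.7
labels[2, 4] = 1.3
labels[3, 4] = 0.2
print(labeled_task_indices(labels))  # tensor([1, 4])
```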

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiTaskDNN(nn.Module):

    def __init__(self, n_tasks, 
                 input_dim=1024, 
                 output_dim=1, 
                 hidden_dim=[1024, 100], 
                 inits=['xavier_normal', 'kaiming_uniform'],
                 act_function=['relu', 'leaky_relu'], 
                 dropouts=[0.10, 0.25], 
                 batch_norm=True):
        # Each hidden block in forward() is Linear -> activation -> BatchNorm -> Dropout.
        
        super(MultiTaskDNN, self).__init__()
        self.n_tasks = n_tasks
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.act_function = act_function
        self.batch_norm = batch_norm
        current_dim = input_dim
        self.layers = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        self.bns = nn.ModuleList()
        
        for k, hdim in enumerate(hidden_dim):
            self.layers.append(nn.Linear(current_dim, hdim))
            self.bns.append(nn.BatchNorm1d(hdim, eps=2e-1))
            
            current_dim = hdim
            
            if inits[k] == 'xavier_normal':
                nn.init.xavier_normal_(self.layers[k].weight)
            elif inits[k] == 'kaiming_normal':
                nn.init.kaiming_normal_(self.layers[k].weight)
            elif inits[k] == 'xavier_uniform':
                nn.init.xavier_uniform_(self.layers[k].weight)
            elif inits[k] == 'kaiming_uniform':
                nn.init.kaiming_uniform_(self.layers[k].weight)
                
            self.dropouts.append(nn.Dropout(dropouts[k]))
        # n_targets
        self.heads = nn.ModuleList()
        for _ in range(self.n_tasks):
            self.heads.append(nn.Linear(current_dim, output_dim))


    def forward(self, x):
        
        for k, layer in enumerate(self.layers):
            x = layer(x)
            if self.act_function[k] == 'sigmoid':
                x = torch.sigmoid(x)
            elif self.act_function[k] == 'relu':
                x = F.relu(x)
            elif self.act_function[k] == 'leaky_relu':
                x = F.leaky_relu(x)
            
            if self.batch_norm:
                x = self.bns[k](x)
                
            x = self.dropouts[k](x)
        
        # One output of shape [batch, output_dim] per task head.
        outputs = []
        for head in self.heads:
            outputs.append(head(x))
        return outputs

# Train the model

for ep in range(N_EPOCHS):
    b_loss = []
    for i, (fps, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(fps)
        loss = torch.tensor(0.0, device=device)
        for j, w in enumerate(weights):
            # mask keeping labeled molecules for each task
            mask = labels[:, j] != -100.0
            if mask.any():
                # The loss is the sum of the per-task losses: this task has
                # labeled samples in the batch, so add its weighted loss.
                loss += criterion(outputs[j][mask], labels[:, j][mask].view(-1, 1)) * w
        
        loss.backward()
        optimizer.step()
        b_loss.append(loss.item())
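Independently of which heads are evaluated, the per-task Python loop over the loss can be written as one masked, weighted reduction. This is a sketch under assumptions the thread does not state: `criterion` is taken to be mean-reduced MSE (`nn.MSELoss()`), the per-head outputs have been concatenated into a `[batch, n_tasks]` tensor (e.g. `torch.cat(outputs, dim=1)` with `output_dim=1`), and `weights` is a 1-D tensor of per-task weights. `masked_weighted_mse` is a hypothetical name. It should match the loop above in value, but on its own it saves no head computation.

```python
import torch


def masked_weighted_mse(outputs: torch.Tensor, labels: torch.Tensor,
                        weights: torch.Tensor, missing: float = -100.0) -> torch.Tensor:
    """Sum over tasks j of weights[j] * MSE over the labeled samples of task j.

    outputs, labels: [batch, n_tasks]; weights: [n_tasks].
    Tasks without any label in the batch contribute zero.
    """
    mask = labels != missing
    # Squared error only where a label exists; zero (and zero gradient) elsewhere.
    sq_err = torch.where(mask, (outputs - labels) ** 2, torch.zeros_like(outputs))
    counts = mask.sum(dim=0)                            # labeled samples per task
    per_task = sq_err.sum(dim=0) / counts.clamp(min=1)  # mean over labeled samples
    return (per_task * weights).sum()


# Usage on random data: 5 samples, 3 tasks.
torch.manual_seed(0)
outputs = torch.randn(5, 3, requires_grad=True)
labels = torch.randn(5, 3)
labels[[0, 2], 0] = -100.0  # task 0: 3 of 5 samples labeled
labels[:, 1] = -100.0       # task 1: no labels in this batch
weights = torch.tensor([0.5, 2.0, 1.0])
loss = masked_weighted_mse(outputs, labels, weights)
loss.backward()
```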

You could pass the target tensor into the forward method and use it to skip the computation of a head in the last for loop whenever its output is not needed (i.e. when no target is available for that task).
During validation, you would have to make sure the target tensor is ignored, since I assume you want to compute the outputs of all self.heads there.
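A minimal sketch of this suggestion, with a toy trunk standing in for the hidden layers (`SparseHeadNet`, its sizes, and the `col_idx` argument name are illustrative, not from the original model). Instead of the raw target tensor it takes the list of labeled column indices, which can be derived from the targets; passing `None` evaluates every head, as validation needs.

```python
from typing import Optional, Sequence

import torch
import torch.nn as nn


class SparseHeadNet(nn.Module):
    """Toy shared trunk with one linear head per task (illustrative only)."""

    def __init__(self, n_tasks: int, input_dim: int = 16, hidden_dim: int = 8):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        self.heads = nn.ModuleList([nn.Linear(hidden_dim, 1) for _ in range(n_tasks)])

    def forward(self, x: torch.Tensor,
                col_idx: Optional[Sequence[int]] = None) -> torch.Tensor:
        h = self.trunk(x)
        if col_idx is None:
            # Validation/inference: ignore targets and evaluate every head.
            col_idx = range(len(self.heads))
        # Training: evaluate only the heads of tasks labeled in this batch.
        # Each head returns [batch, 1]; concatenating gives [batch, len(col_idx)].
        return torch.cat([self.heads[i](h) for i in col_idx], dim=1)


model = SparseHeadNet(n_tasks=1000)
x = torch.randn(32, 16)
print(model(x, col_idx=[3, 17, 512]).shape)  # torch.Size([32, 3])
model.eval()
with torch.no_grad():
    print(model(x).shape)  # torch.Size([32, 1000])
```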

Thanks for your suggestion. I passed the column indices of the labeled tasks to the forward method. It works well and speeds training up a lot.

    def forward(self, x, col_idx):
        
        for k, layer in enumerate(self.layers):
            x = layer(x)
            if self.act_function[k] == 'sigmoid':
                x = torch.sigmoid(x)
            elif self.act_function[k] == 'relu':
                x = F.relu(x)
            elif self.act_function[k] == 'leaky_relu':
                x = F.leaky_relu(x)
            
            if self.batch_norm:
                x = self.bns[k](x)
                
            x = self.dropouts[k](x)
        
        # Evaluate only the heads listed in col_idx (the tasks labeled in this
        # batch). Column k of the result is the prediction for task col_idx[k].
        # torch.stack keeps the result on x's device and dtype, unlike a
        # torch.zeros buffer allocated on the CPU from the global fps.
        outputs = torch.stack(
            [torch.squeeze(self.heads[idx](x), 1) for idx in col_idx], dim=1
        )
        return outputs
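For completeness, a sketch of a training step that drives a `forward(x, col_idx)` model like the one above: it derives `col_idx` from the labels, evaluates only those heads, and applies the per-task mask and weight within the selected columns. `TinyMultiTask` is a hypothetical stand-in for `MultiTaskDNN`, `train_step` is an illustrative name, MSE is assumed for `criterion`, and -100.0 is the missing-label sentinel from the original loop.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

MISSING = -100.0  # missing-label sentinel, as in the original training loop


class TinyMultiTask(nn.Module):
    """Hypothetical stand-in for MultiTaskDNN with a forward(x, col_idx) signature."""

    def __init__(self, n_tasks: int, input_dim: int = 16, hidden_dim: int = 8):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        self.heads = nn.ModuleList([nn.Linear(hidden_dim, 1) for _ in range(n_tasks)])

    def forward(self, x, col_idx):
        h = self.trunk(x)
        # Column k holds the prediction for task col_idx[k].
        return torch.stack([self.heads[i](h).squeeze(1) for i in col_idx], dim=1)


def train_step(model, optimizer, fps, labels, weights):
    """One optimization step that only evaluates the heads of labeled tasks."""
    optimizer.zero_grad()
    # Tasks with at least one label in this batch, as a plain list of ints.
    col_idx = (labels != MISSING).any(dim=0).nonzero(as_tuple=True)[0].tolist()
    if not col_idx:
        return None  # no labels at all: skip the batch
    outputs = model(fps, col_idx)    # [batch, len(col_idx)]
    sub_labels = labels[:, col_idx]  # same column order as outputs
    loss = outputs.new_zeros(())
    for k, j in enumerate(col_idx):
        mask = sub_labels[:, k] != MISSING  # labeled samples of task j
        loss = loss + F.mse_loss(outputs[mask, k], sub_labels[mask, k]) * weights[j]
    loss.backward()
    optimizer.step()
    return loss.item()


# Usage on random data: 50 tasks, only tasks 5 and 42 labeled in this batch.
torch.manual_seed(0)
model = TinyMultiTask(n_tasks=50)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
fps = torch.randn(32, 16)
labels = torch.full((32, 50), MISSING)
labels[:8, 5] = torch.randn(8)
labels[10:20, 42] = torch.randn(10)
weights = [1.0] * 50
print(train_step(model, opt, fps, labels, weights))
```

Because the unlabeled heads never run, they get no gradient for that batch. With `optimizer.zero_grad()` resetting gradients to `None` (its default since PyTorch 2.0), the built-in optimizers skip parameters whose `.grad` is `None`, so those heads should be left untouched by that step.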