Branched model for handling optional inputs

About half of the elements in the dataset have certain features that need to go through a separate layer first. These are not missing values; the features simply don’t exist for the other elements. Since everything is part of the same problem/objective, a separate model should not be used. My best guess for a solution is the following:

  • Define two forward paths (two branches with shared modules), where the output of the optional branch is replaced with a tensor of zeros when the optional data does not exist,
  • Split the dataset in two, based on whether an element has the optional features,
  • Run two forward passes, one for each dataset,
  • Sum the losses from both forward passes, followed by one backward pass.
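The splitting step in the list above can be sketched with a boolean mask over the dataset. This is a minimal sketch under assumptions: the optional features are stored in a padded tensor (rows for elements without them are just placeholders), `has_optional` is a hypothetical per-element flag, and the data is random. `torch.utils.data.Subset` then yields two datasets, each with its own `DataLoader`.

```python
import torch
from torch.utils.data import TensorDataset, Subset, DataLoader

torch.manual_seed(0)

n = 10
x_always = torch.randn(n, 16)              # features every element has
x_optional = torch.randn(n, 8)             # padded; only meaningful where has_optional is True
has_optional = torch.arange(n) % 2 == 0    # hypothetical flag: about half the elements
y = torch.randn(n, 1)

dataset = TensorDataset(x_always, x_optional, y)

# indices of each group, converted to plain Python ints for Subset
idx_opt = has_optional.nonzero(as_tuple=True)[0].tolist()
idx_plain = (~has_optional).nonzero(as_tuple=True)[0].tolist()

ds_with_opt = Subset(dataset, idx_opt)
ds_without_opt = Subset(dataset, idx_plain)

loader_with_opt = DataLoader(ds_with_opt, batch_size=4, shuffle=True)
loader_without_opt = DataLoader(ds_without_opt, batch_size=4, shuffle=True)

print(len(ds_with_opt), len(ds_without_opt))  # 5 5
```

In the training loop, the two loaders could be iterated together with `zip(loader_with_opt, loader_without_opt)`; note that `zip` stops at the shorter loader, which matters if the two groups are not the same size.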

I’ve realized I should stay away from BatchNorm, and it seems that I also need to turn off the gradients of the optional module when its output is replaced by zeros. Is there a better approach altogether? If not, what are some other things to keep in mind as I do this? (The weights of my optional layer end up as NaN, but that could be a completely different issue.) And if my solution is already perfect, this post can serve as a solution for anyone facing the same problem.
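For the NaN weights, one generic check (not specific to this model) is to scan the parameters and their gradients for non-finite values after every backward pass, so the first step where a NaN appears can be located. A minimal sketch; `report_non_finite` is a hypothetical helper name:

```python
import torch
import torch.nn as nn

def report_non_finite(model: nn.Module) -> list[str]:
    """Return names of parameters whose values or gradients contain NaN/Inf."""
    bad = []
    for name, p in model.named_parameters():
        if not torch.isfinite(p).all():
            bad.append(f"{name} (weights)")
        if p.grad is not None and not torch.isfinite(p.grad).all():
            bad.append(f"{name} (grad)")
    return bad

# usage: call right after loss.backward(), before optimizer.step()
model = nn.Linear(4, 2)
model(torch.randn(3, 4)).sum().backward()
print(report_non_finite(model))  # []

# simulate a corrupted weight to see the report
with torch.no_grad():
    model.weight[0, 0] = float("nan")
print(report_non_finite(model))  # ['weight (weights)']
```

If a gradient shows up as non-finite first, `torch.autograd.set_detect_anomaly(True)` can additionally point to the backward operation that produced it, at a noticeable speed cost.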

Here’s a simplified version of my model and training process:

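import torch

def forward(self, x_always, x_optional=None):
    x1_out = self.fc(x_always)
    if x_optional is not None:
        x2_out = self.fc_optional(x_optional)
    else:
        # placeholder of shape (batch_size, fc_optional_out_size),
        # created on the same device/dtype as x1_out
        x2_out = x1_out.new_zeros(x_always.size(0), 128)

    final_in = torch.cat([x1_out, x2_out], dim=1)
    return self.fc_final(final_in)

# training loop (inside the loop over batches)
    optimizer.zero_grad()

    # pass 1: elements without the optional features
    for _p in model.fc_optional.parameters():
        _p.requires_grad = False
    out = model(x1)
    loss1 = criterion(out, y1)

    # pass 2: elements with the optional features
    for _p in model.fc_optional.parameters():
        _p.requires_grad = True
    out = model(x2, x2_optional)
    loss2 = criterion(out, y2)

    # sum both losses and do a single backward pass
    loss = loss1 + loss2
    loss.backward()
    optimizer.step()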

I don’t think toggling requires_grad would be needed. As long as the self.fc_optional module is not used in a forward pass, it won’t get any gradients from that pass.
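This can be checked directly: after a backward pass through a forward call that skips the optional branch, that branch’s `.grad` attributes are still `None` (on a freshly created model). A minimal sketch with a hypothetical two-layer module; the branches are summed here instead of concatenated only to keep it short.

```python
import torch
import torch.nn as nn

class Branched(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)
        self.fc_optional = nn.Linear(2, 3)

    def forward(self, x_always, x_optional=None):
        out = self.fc(x_always)
        if x_optional is not None:
            out = out + self.fc_optional(x_optional)
        return out.sum()

model = Branched()
model(torch.randn(5, 4)).backward()   # optional branch not used

print(model.fc.weight.grad is not None)        # True
print(model.fc_optional.weight.grad is None)   # True
```

PyTorch’s built-in optimizers skip parameters whose `.grad` is `None`, so an unused branch is not updated by that step either.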

This might work, but I would also recommend checking an alternative: return the same number of output features from both branches without concatenating zeros to them. E.g. you could check whether an additional layer that creates the expected output feature dimension works, or maybe use expand etc.
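One possible reading of the additional-layer suggestion, as a sketch: each branch ends in its own projection layer that maps to the same feature size, so the shared `fc_final` always sees a tensor of the same width and no zeros are concatenated. The names `proj_plain` and `proj_with_opt`, and all sizes, are hypothetical.

```python
import torch
import torch.nn as nn

class ProjectedBranches(nn.Module):
    def __init__(self, n_always=16, n_optional=8, hidden=64, opt_out=128, shared=32):
        super().__init__()
        self.fc = nn.Linear(n_always, hidden)
        self.fc_optional = nn.Linear(n_optional, opt_out)
        # hypothetical extra layers: both branches end in `shared` features
        self.proj_plain = nn.Linear(hidden, shared)
        self.proj_with_opt = nn.Linear(hidden + opt_out, shared)
        self.fc_final = nn.Linear(shared, 1)

    def forward(self, x_always, x_optional=None):
        h = torch.relu(self.fc(x_always))
        if x_optional is None:
            z = self.proj_plain(h)
        else:
            h_opt = torch.relu(self.fc_optional(x_optional))
            z = self.proj_with_opt(torch.cat([h, h_opt], dim=1))
        return self.fc_final(torch.relu(z))

model = ProjectedBranches()
out_plain = model(torch.randn(4, 16))
out_opt = model(torch.randn(6, 16), torch.randn(6, 8))
print(out_plain.shape, out_opt.shape)  # torch.Size([4, 1]) torch.Size([6, 1])
```

With zero padding, the columns of `fc_final`’s weight that multiply the zeros get no gradient from the plain branch; here each branch instead learns its own mapping into the shared space, at the cost of one extra layer per branch.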
