About half of the elements in my dataset have certain extra features, which need to go through a separate layer first. These features are not missing values; they simply don’t exist for the other elements. Everything is part of the same problem/objective, so I don’t want to use a separate model. My best guess at a solution is the following:
- Define two forward functions (two branches with shared modules), where the output of the optional branch is replaced with a tensor of zeroes when the optional data does not exist,
- Split the dataset in two, based on whether an element has the optional features (see the sketch after this list),
- Run two forward passes, one for each dataset,
- Sum the losses from both forward passes, followed by one backward pass.
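For the split in the second step, the kind of indexing I have in mind looks roughly like this (a simplified sketch; has_optional, x_all, x_opt and y are placeholder names for my actual fields):

import torch

# has_optional: (N,) bool mask marking which elements carry the extra features
# x_all:        (N, d_always) features that every element has
# x_opt:        (N, d_optional) extra features, only meaningful where has_optional is True
x1 = x_all[~has_optional]            # elements without the optional features
x2 = x_all[has_optional]             # elements with the optional features
x2_optional = x_opt[has_optional]    # their extra features
y1 = y[~has_optional]                # targets for the first group
y2 = y[has_optional]                 # targets for the second group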
I’ve realized I should stay away from BatchNorm, and it seems I also need to turn off the gradients of the optional module when its output is replaced with zeros. Is there maybe a completely better approach to this? If not, what are some other things to keep in mind as I do this (the weights of my optional layer end up as ‘nan’, but that could be a completely different issue)? And if my solution is already solid, then this post can serve as a solution for anyone facing the same problem.
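To track down where the nan first appears, the debugging step I’m planning to try is PyTorch’s anomaly detection plus a finiteness check on the gradients, something like this (just a debugging sketch, not part of the model):

import torch

torch.autograd.set_detect_anomaly(True)   # makes backward raise an error at the op that produced nan/inf

loss.backward()
for name, p in model.named_parameters():
    if p.grad is not None and not torch.isfinite(p.grad).all():
        print(f"non-finite gradient in {name}")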
Here’s a simplified version of my model and training process:
def forward(self, x_always, x_optional=None):
    x1_out = self.fc(x_always)
    if x_optional is not None:
        x2_out = self.fc_optional(x_optional)
    else:
        # (batch_size, fc_optional_out_size); new_zeros keeps device and dtype consistent with x1_out
        x2_out = x1_out.new_zeros(x_always.size(0), 128)
    final_in = torch.cat([x1_out, x2_out], dim=1)
    return self.fc_final(final_in)
# training loop
opt.zero_grad()  # clear gradients accumulated from the previous step

# pass 1: elements without the optional features
for _p in model.fc_optional.parameters():
    _p.requires_grad = False
out = model(x1)
loss1 = criterion(out, y1)

# pass 2: elements with the optional features
for _p in model.fc_optional.parameters():
    _p.requires_grad = True
out = model(x2, x2_optional)
loss2 = criterion(out, y2)

loss = loss1 + loss2
loss.backward()
opt.step()
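In case it’s useful for comparison, a single-pass variant of the same idea (masking the optional branch’s output instead of splitting the batch) could look roughly like this; untested on my side, where has_optional would be a (batch_size,) bool mask and x_optional zero-padded for rows that don’t have the extra features:

def forward(self, x_always, x_optional, has_optional):
    x1_out = self.fc(x_always)
    x2_out = self.fc_optional(x_optional)
    # zero the optional branch for rows without the extra features;
    # those rows then contribute no gradient to fc_optional
    x2_out = x2_out * has_optional.unsqueeze(1).to(x2_out.dtype)
    final_in = torch.cat([x1_out, x2_out], dim=1)
    return self.fc_final(final_in)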