About half of the elements in the dataset have certain extra features that need to go through a separate layer first. These are not missing values; the features simply don’t exist for every element. Everything belongs to the same problem and objective, so a separate model should not be used. My best guess for a solution is the following:
- Define two forward paths (two branches with shared modules), where the output of the optional branch is replaced by a tensor of zeros when the optional data does not exist.
- Split the dataset in two, depending on whether an element has the optional features.
- Run two forward passes, one per subset.
- Sum the losses from both forward passes, then do a single backward pass.
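The splitting step can also be done per batch instead of on the whole dataset. A minimal sketch, assuming each sample carries a boolean flag (the name `has_optional` is hypothetical) and that the optional features are stored in a padded tensor whose placeholder rows are simply dropped:

```python
import torch

def split_batch(x_always, x_optional, y, has_optional):
    """Split a mixed batch into (without-optional, with-optional) sub-batches.

    Assumption: x_optional has one row per sample; rows where has_optional
    is False are placeholders and are discarded here.
    """
    without = ~has_optional
    plain = (x_always[without], y[without])
    extra = (x_always[has_optional], x_optional[has_optional], y[has_optional])
    return plain, extra

# toy batch of 4 samples, 2 of which have the optional features
x_always = torch.randn(4, 3)
x_optional = torch.randn(4, 5)  # rows 1 and 3 are placeholders
y = torch.tensor([0., 1., 2., 3.])
has_optional = torch.tensor([True, False, True, False])

(plain_x, plain_y), (extra_x, extra_opt, extra_y) = split_batch(x_always, x_optional, y, has_optional)
print(plain_y)  # tensor([1., 3.])
print(extra_y)  # tensor([0., 2.])
```

If a batch happens to contain no samples of one kind, it is probably safest to skip that loss term: a mean-reduced loss over an empty tensor can come out as NaN.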
I’ve realized I should stay away from BatchNorm, and it seems that I also need to turn off the gradients of the optional module when its output is replaced by zeros. Is there a completely better approach for this? If not, what other things should I keep in mind (the weights of my optional layer end up as NaN, but that could be a completely different issue)? And if my solution is already perfect, this post can serve as a solution for anyone facing the same problem.
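On the gradient question: a parameter only receives a gradient if it is part of the graph that `backward()` traverses. In the pass where the optional output is replaced by zeros, `fc_optional` is never called, so toggling its `requires_grad` for that pass should not change anything. A small check on toy layers (the sizes are assumptions; only the names mirror the code in this post):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy layers standing in for fc / fc_optional / fc_final
fc = nn.Linear(3, 4)
fc_optional = nn.Linear(5, 4)
fc_final = nn.Linear(8, 1)

x_always = torch.randn(6, 3)
y = torch.randn(6, 1)

# Pass without optional features: the optional branch output is all zeros
x2_out = x_always.new_zeros(x_always.size(0), 4)
out = fc_final(torch.cat([fc(x_always), x2_out], dim=1))
loss = nn.functional.mse_loss(out, y)
loss.backward()

# fc_optional never entered the graph, so autograd left its .grad untouched
print(fc_optional.weight.grad)  # None
# the fc_final weight columns that read the zero block get an all-zero gradient
print(fc_final.weight.grad[:, 4:])
```

If the NaN weights persist, `torch.autograd.set_detect_anomaly(True)` can help locate the operation that first produces a NaN during `backward()`, at the cost of slower training.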
Here’s a simplified version of my model and training process:
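```python
import torch

# method of the model (an nn.Module subclass)
def forward(self, x_always, x_optional=None):
    x1_out = self.fc(x_always)
    if x_optional is not None:
        x2_out = self.fc_optional(x_optional)
    else:
        # (batch_size, fc_optional_out_size); new_zeros matches the device and dtype of x_always
        x2_out = x_always.new_zeros(x_always.size(0), 128)
    final_in = torch.cat([x1_out, x2_out], dim=1)
    return self.fc_final(final_in)

# training loop (one step)
opt.zero_grad()  # clear gradients left over from the previous step
for _p in model.fc_optional.parameters():
    _p.requires_grad = False
out = model(x1)
loss1 = criterion(out, true_y1)  # targets of the subset without optional features
for _p in model.fc_optional.parameters():
    _p.requires_grad = True
out = model(x2, x2_optional)
loss2 = criterion(out, true_y2)  # targets of the subset with optional features
loss = loss1 + loss2
loss.backward()
opt.step()
```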
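A possible variant of the same idea, sketched under assumptions: instead of splitting the data into two subsets, keep mixed batches with a per-sample boolean mask (hypothetical name `has_optional`), run `fc_optional` only on the rows that have the optional features, and write its output into a zero tensor. This keeps the shared modules and the zeros for missing features, but needs only one forward pass and one loss per batch. The layer sizes and class name below are assumptions, not the actual model from this post.

```python
import torch
import torch.nn as nn

class MaskedModel(nn.Module):
    # Hypothetical sizes; only the optional branch width (128) comes from the post
    def __init__(self, in_always=3, in_optional=5, hidden=16, opt_out=128, out_size=1):
        super().__init__()
        self.fc = nn.Linear(in_always, hidden)
        self.fc_optional = nn.Linear(in_optional, opt_out)
        self.fc_final = nn.Linear(hidden + opt_out, out_size)
        self.opt_out = opt_out

    def forward(self, x_always, x_optional, has_optional):
        x1_out = self.fc(x_always)
        # zeros for every row, on the same device/dtype as the input
        x2_out = x_always.new_zeros(x_always.size(0), self.opt_out)
        if has_optional.any():
            # only rows that really have optional features go through fc_optional;
            # the masked assignment stays differentiable w.r.t. fc_optional
            x2_out[has_optional] = self.fc_optional(x_optional[has_optional])
        return self.fc_final(torch.cat([x1_out, x2_out], dim=1))

torch.manual_seed(0)
model = MaskedModel()
opt = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

x_always = torch.randn(4, 3)
x_optional = torch.randn(4, 5)  # placeholder rows where has_optional is False
has_optional = torch.tensor([True, False, True, False])
true_y = torch.randn(4, 1)

opt.zero_grad()
loss = criterion(model(x_always, x_optional, has_optional), true_y)
loss.backward()
opt.step()
```

One difference from the two-pass version: the loss is averaged once over the mixed batch instead of summing two separately averaged losses, so the relative weighting of the two kinds of elements changes with their proportion in each batch.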