Partial model parameter updates due to missing input data

Below is a simplified version of a model I am developing:

import torch
import torch.nn as nn
import torch.nn.functional as F


class ExampleModel(nn.Module):
    def __init__(self, model_1, model_2):
        super(ExampleModel, self).__init__()

        self.model_1_name = 'Model_1'
        self.model_1 = model_1
        self.add_module(self.model_1_name, self.model_1)

        self.model_2_name = 'Model_2'
        self.model_2 = model_2
        self.add_module(self.model_2_name, self.model_2)

        self.fc_name = 'Linear_layer'
        # assumes each sub-model outputs a single feature per sample,
        # so the concatenated width is 2
        self.fc_linear = nn.Linear(2, 1)
        self.add_module(self.fc_name, self.fc_linear)

    def forward(self, input_a, input_b):
        x1 = self.model_1(input_a)
        x2 = self.model_2(input_b)
       
        x = torch.cat((x1, x2), 1)

        x = self.fc_linear(F.relu(x))

        return x
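For context, a minimal usage sketch, assuming model_1 and model_2 each map their input to a single feature per sample (which is what the nn.Linear(2, 1) head above implies); the sub-models and shapes here are placeholders:

import torch
import torch.nn as nn

model = ExampleModel(model_1=nn.Linear(4, 1), model_2=nn.Linear(3, 1))

input_a = torch.randn(8, 4)    # batch of 8 samples for the first sub-model
input_b = torch.randn(8, 3)    # batch of 8 samples for the second sub-model

out = model(input_a, input_b)  # shape: (8, 1)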

My end goal is to support cases where data in the input batches might be missing, and to avoid updating model parameters based on missing data or on any “dummy” data that is passed in.

Since the data sent to forward is batched, a single batch might have missing data in some places and valid data in others. Currently the missing data is replaced with a torch.zeros tensor.
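Concretely, the zero-filling currently looks something like the following per batch, assuming input_b is the field that can be missing; the shapes and indices are illustrative:

import torch

batch_size, dim_b = 8, 3
input_b = torch.randn(batch_size, dim_b)

# samples 1 and 4 have no data for input_b, so dummy zeros are substituted
missing = [1, 4]
input_b[missing] = torch.zeros(dim_b)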

Is it possible to create an architecture where the gradients generated from these inputs are excluded from the backward pass?
Is there a better way to go about this that does not involve creating torch.zeros tensors as dummy data?

It is doable if you have one availability state per sample, or if you can split the columns into a “core” table and an “optional” table; then multiplying some entries of either the unreduced loss tensor or the “optional” branch's output by zeros does the trick. With varying combinations of missing columns it is much harder to do without hurting performance.
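A minimal sketch of the loss-masking variant, assuming one availability state per sample and a regression target; the names (model, mask, target) and the choice of MSE loss are illustrative, not part of the original model:

import torch
import torch.nn.functional as F

# hypothetical availability mask: 1.0 where input_b holds real data,
# 0.0 where it was filled with dummy zeros
mask = torch.tensor([1.0, 0.0, 1.0, 1.0])

pred = model(input_a, input_b)                                       # (batch, 1)
per_sample = F.mse_loss(pred, target, reduction='none').squeeze(1)   # (batch,)

# multiplying by zero removes both the loss contribution and the gradient
# of the masked samples; normalise by the number of real samples
loss = (per_sample * mask).sum() / mask.sum().clamp(min=1.0)
loss.backward()

The other variant mentioned above is to multiply the optional branch's output by the mask inside forward (e.g. x2 * mask.unsqueeze(1)), so that only model_2 is shielded from the dummy samples while model_1 and the linear head still receive gradients from them.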