Below is a simplified version of a model I am developing:
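```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ExampleModel(nn.Module):
    def __init__(self, model_1, model_2):
        super(ExampleModel, self).__init__()
        self.model_1_name = 'Model_1'
        self.model_1 = model_1
        # Assigning an nn.Module as an attribute already registers it, so
        # add_module registers the same sub-model a second time under this name.
        self.add_module(self.model_1_name, self.model_1)
        self.model_2_name = 'Model_2'
        self.model_2 = model_2
        self.add_module(self.model_2_name, self.model_2)
        self.fc_name = 'Linear_layer'
        # Assumes model_1 and model_2 each output a single feature,
        # so the concatenated vector has 2 features.
        self.fc_linear = nn.Linear(2, 1)
        self.add_module(self.fc_name, self.fc_linear)

    def forward(self, input_a, input_b):
        x1 = self.model_1(input_a)
        x2 = self.model_2(input_b)
        # Concatenate the two branch outputs along the feature dimension.
        x = torch.cat((x1, x2), 1)
        x = self.fc_linear(F.relu(x))
        return x
```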
My end goal is to support batches in which some of the input data may be missing, while avoiding any parameter updates based on the missing entries or on the “dummy” data sent in their place.
Because the data passed to forward is batched, a single batch may contain missing data in some rows and valid data in others. Currently, the missing data is replaced with a torch.zeros tensor.
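As a quick check of why zero-filled placeholders are not inert, the sketch below uses a hypothetical stand-in for model_1 (a single `nn.Linear(4, 1)`) and backpropagates only from an all-zero row. The weight gradient vanishes because it is proportional to the input, but the bias still receives a gradient; in deeper networks, later layers see nonzero activations even for a zero input, so their weights are generally updated as well.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-in for model_1: a single linear layer.
sub_model = nn.Linear(4, 1)

# A batch of two rows; the second row is a zero-filled placeholder.
batch = torch.stack([torch.randn(4), torch.zeros(4)])

# Backpropagate from the placeholder row only.
out = sub_model(batch)
out[1].sum().backward()

print(sub_model.weight.grad)  # all zeros: the weight gradient is scaled by the (zero) input
print(sub_model.bias.grad)    # tensor([1.]): the bias is still updated by the dummy row
```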
Is it possible to build the architecture so that gradients generated from these placeholder inputs are excluded from the backward pass?
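One way to approach this at the architecture level is sketched below as a hypothetical `MaskedExampleModel`; the extra `mask_a`/`mask_b` arguments and the `in_features` parameter are assumptions, not part of the model above. Each branch gets a boolean mask, and `torch.where` replaces the features of missing rows with zeros. Because `torch.where` only propagates gradient to the elements it selects, rows whose mask is False send no gradient into model_1 or model_2, while fc_linear can still learn from whichever branch is present for that row.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedExampleModel(nn.Module):
    """Hypothetical variant of ExampleModel that takes a mask per branch."""

    def __init__(self, model_1, model_2, in_features):
        super().__init__()
        self.model_1 = model_1
        self.model_2 = model_2
        self.fc_linear = nn.Linear(in_features, 1)

    def forward(self, input_a, input_b, mask_a, mask_b):
        # mask_a / mask_b: bool tensors of shape (batch,), True where data is present.
        x1 = self.model_1(input_a)
        x2 = self.model_2(input_b)
        # torch.where routes gradient only to the selected elements, so rows
        # whose mask is False contribute no gradient to model_1 / model_2.
        x1 = torch.where(mask_a.unsqueeze(1), x1, torch.zeros_like(x1))
        x2 = torch.where(mask_b.unsqueeze(1), x2, torch.zeros_like(x2))
        x = torch.cat((x1, x2), 1)
        return self.fc_linear(F.relu(x))
```

This still runs the sub-models on the placeholder rows, so the placeholders must stay finite (zeros are fine): with NaN placeholders, the zero upstream gradient multiplied by a NaN input would still yield NaN weight gradients. Layers that track running statistics, such as BatchNorm, would also still see the placeholder rows.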
Is there a better approach that does not involve creating torch.zeros tensors as dummy data?
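Alternatively, the dummy inputs could be dropped entirely by passing each branch only the rows that are actually present, together with a boolean mask over the full batch. The sketch below uses a hypothetical helper `run_on_present` and a hypothetical `CompactExampleModel` (illustrations only, not existing PyTorch APIs): each sub-model runs only on the present rows, and its output is scattered into a full-batch feature tensor. Missing rows end up with a zero feature vector that is not connected to the sub-model's parameters, so nothing is computed or backpropagated through the sub-model for them.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def run_on_present(sub_model, inputs, mask):
    """Hypothetical helper: apply sub_model only to rows where mask is True.

    inputs holds just the present rows, so inputs.shape[0] == mask.sum();
    missing rows are never passed through sub_model.
    """
    feats = sub_model(inputs)
    # Scatter the computed features into a full-batch tensor; rows with
    # missing data keep a zero vector with no path to sub_model's parameters.
    out = feats.new_zeros(mask.shape[0], feats.shape[1])
    out[mask] = feats
    return out


class CompactExampleModel(nn.Module):
    """Hypothetical variant of ExampleModel that never sees placeholder inputs."""

    def __init__(self, model_1, model_2, in_features):
        super().__init__()
        self.model_1 = model_1
        self.model_2 = model_2
        self.fc_linear = nn.Linear(in_features, 1)

    def forward(self, input_a, input_b, mask_a, mask_b):
        # input_a / input_b contain only the present rows for each branch;
        # mask_a / mask_b are bool tensors of shape (batch,) over the full batch.
        x1 = run_on_present(self.model_1, input_a, mask_a)
        x2 = run_on_present(self.model_2, input_b, mask_b)
        x = torch.cat((x1, x2), 1)
        return self.fc_linear(F.relu(x))
```

Since the zero rows now live at the feature level rather than the input level, they could be swapped for a different fill value (for example a learnable placeholder vector) without affecting the gradients of model_1 or model_2.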