Handling edge cases/handling some inputs differently

I am building a `torch.nn.Module` for graph convolutional neural networks. For concreteness, imagine the input is a configuration of points with (x, y) coordinates, and the task is to predict a scalar output for the configuration. I want an elegant way to handle, within the same `torch.nn.Module`, the edge case of a configuration with no (x, y) points, i.e. an empty configuration. The output in this case is still a scalar, and I want to optimize this scalar for the empty configuration jointly with the network weights during training. The scalar may be considered an additional weight of the model.
One problem that may arise:
- Each batch fed to the network may contain a mix of empty and non-empty configurations.


To handle such cases I usually use the following pseudocode:

```python
# cond: tensor of shape (n_batch,) with 1.0 for empty configurations, 0.0 otherwise
cond = batch_data.is_empty()

out_true = your_non_empty_code(batch_data)        # compute everything as if non-empty
out_false = self.empty_value.expand_as(out_true)  # your parameter for the empty ones

res = out_true * (1 - cond) + out_false * cond
```
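To make this concrete, here is a minimal runnable sketch. The model name `ConfigScorer`, the per-point MLP, the sum pooling, and the padded `points`/`mask` batch layout are all my assumptions for illustration, not from the original post; the key part is the learnable `empty_value` parameter blended in via the mask trick above.

```python
import torch
import torch.nn as nn

class ConfigScorer(nn.Module):
    """Hypothetical model: per-point MLP + sum pooling over real points,
    with a learnable scalar for empty configurations."""

    def __init__(self, hidden=16):
        super().__init__()
        self.point_mlp = nn.Sequential(
            nn.Linear(2, hidden), nn.ReLU(), nn.Linear(hidden, 1)
        )
        # The scalar for the empty configuration, trained with the other weights
        self.empty_value = nn.Parameter(torch.zeros(1))

    def forward(self, points, mask):
        # points: (n_batch, max_points, 2), zero-padded
        # mask:   (n_batch, max_points) bool, True where a real point exists
        per_point = self.point_mlp(points).squeeze(-1)       # (n_batch, max_points)
        out_true = (per_point * mask.float()).sum(dim=1)     # pool only real points
        cond = (~mask).all(dim=1).float()                    # 1.0 for empty configs
        out_false = self.empty_value.expand_as(out_true)
        return out_true * (1 - cond) + out_false * cond

model = ConfigScorer()
points = torch.randn(2, 3, 2)
mask = torch.tensor([[True, True, False],    # non-empty configuration
                     [False, False, False]]) # empty configuration
out = model(points, mask)  # out[1] == model.empty_value
```

Since the empty output is `out_false * cond` with `cond = 1`, gradients flow into `empty_value` only from the empty entries of the batch, which is exactly the desired behavior. An equivalent formulation is `torch.where(cond.bool(), out_false, out_true)`.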