I am trying to predict the values of 7 classes over 4 time steps. The predicted values are numerical, and at each time step the values of the 7 classes must sum to 100. One class can have a value of 100, but then the other classes are by definition 0. My LSTM outputs a tensor of shape [batch_size, sequence_length, output_size], where batch_size = 64, sequence_length = 4 and output_size = 7. However, my LSTM currently sometimes produces predictions such as the following:
```
[[0, 0, 0, 44, 0, 6, 0],
 [100, 0, 0, 0, 0, 0, 0],
 [78, 0, 0, 5, 0, 0, 0],
 [0, 30, 0, 0, 70, 0, 0]]
```
Each row is one time step of the four-step sequence. At time step 1 the values sum to 44 + 6 = 50, and at time step 3 they sum to 78 + 5 = 83, so neither adds up to 100. I want my LSTM model to predict values that sum to 100 at every time step.
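As a quick check of the arithmetic above, the per-time-step sums of the example prediction can be computed in plain Python (the list is copied from the example; nothing else is assumed):

```python
# The example prediction: 4 time steps (rows) x 7 classes (columns)
prediction = [
    [0, 0, 0, 44, 0, 6, 0],
    [100, 0, 0, 0, 0, 0, 0],
    [78, 0, 0, 5, 0, 0, 0],
    [0, 30, 0, 0, 70, 0, 0],
]

# Sum the 7 class values at each time step
step_sums = [sum(step) for step in prediction]
print(step_sums)  # [50, 100, 83, 100]

# 1-based time steps whose values do not add up to 100
bad_steps = [t + 1 for t, s in enumerate(step_sums) if s != 100]
print(bad_steps)  # [1, 3]
```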
My LSTM model:
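```python
import torch
import torch.nn as nn
import torch.nn.init as init


class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)
        init.xavier_uniform_(self.linear.weight)

    def forward(self, x):
        x, _ = self.lstm(x)  # [batch_size, sequence_length, hidden_size]
        x = self.linear(x)   # [batch_size, sequence_length, output_size]
        # Keeps each value in [0, 100]; does not constrain the per-step sum
        x = torch.clamp(x, min=0, max=100)
        return x
```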
Here I use `torch.clamp(x, min=0, max=100)` to keep the output of each class at each time step between 0 and 100; before I added it, the model sometimes predicted negative values. I now wonder whether there is a similar way to constrain the outputs at each time step to sum to exactly 100, no more and no less. Something like:
x = torch(sum(x) = 100)
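One common way to get this behaviour (a sketch, not tested on the poster's data) is to replace the clamp with a softmax over the class dimension and multiply by 100. Softmax makes every value positive and makes the 7 values at each time step sum to 1, so the scaled output sums to 100 by construction. The sizes in the usage lines (`input_size=10`, `hidden_size=32`) are placeholders chosen only for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.init as init


class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)
        init.xavier_uniform_(self.linear.weight)

    def forward(self, x):
        x, _ = self.lstm(x)       # [batch, seq_len, hidden_size]
        logits = self.linear(x)   # [batch, seq_len, output_size]
        # Softmax over the class dimension: values are positive and the
        # 7 values at each time step sum to 1; scaling by 100 makes them
        # sum to 100, so no clamp is needed.
        return torch.softmax(logits, dim=-1) * 100


# Placeholder sizes for illustration (input_size=10, hidden_size=32)
model = LSTMModel(input_size=10, hidden_size=32, output_size=7)
x = torch.randn(64, 4, 10)       # [batch_size, sequence_length, input_size]
out = model(x)
print(out.shape)                 # torch.Size([64, 4, 7])
print(out.sum(dim=-1)[0])        # four values, each approximately 100
```

One caveat of this choice: softmax never outputs exactly 0 or exactly 100, it only approaches them as the logits grow, so a target such as [100, 0, 0, 0, 0, 0, 0] can only be matched approximately (for example 99.9 and small remainders). Rounding afterwards is an option if exact values are needed.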
Currently I rescale the output after it leaves the model (x / sum * 100), but I wonder whether I can do this inside the model, so that the model is forced to predict values that (closely) add up to 100. At the moment it sometimes predicts sums below 50, and sometimes even 0. I suspect that rescaling such outputs afterwards only hurts accuracy, and a sum of 0 cannot be rescaled at all.
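For reference, the post-hoc rescaling described above might look like the following plain-Python sketch (`rescale_to_100` is a hypothetical helper, not from the original post). It shows both issues: the first example time step, which sums to 50, has each of its values doubled, and a step that sums to 0 would divide by zero. For that case the sketch falls back to an equal split of 100/7 per class, which is an arbitrary assumption.

```python
def rescale_to_100(step, eps=1e-8):
    """Rescale one time step's non-negative class values so they sum to 100.

    Hypothetical helper. If the values sum to (almost) 0, plain
    x / sum * 100 would divide by zero, so this sketch returns an equal
    split instead; that fallback is an arbitrary choice.
    """
    total = sum(step)
    if total < eps:
        return [100 / len(step)] * len(step)
    return [v * 100 / total for v in step]


print(rescale_to_100([0, 0, 0, 44, 0, 6, 0]))  # [0.0, 0.0, 0.0, 88.0, 0.0, 12.0, 0.0]
print(rescale_to_100([0, 0, 0, 0, 0, 0, 0]))   # seven equal shares of 100/7
```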
Is this doable within the LSTM model? Or do I simply need to be more clever when rescaling it later? Thanks in advance!