I’m new to deep learning and LSTMs, and I’m having some trouble. My target values are often skewed, and I want to build an LSTM that predicts extreme values more accurately. Please excuse my newbieness.
In R (RStudio), the ranger() function can predict quantiles (a distribution) rather than a single point prediction. From those quantiles I take the median, and this gives me better predictions for extreme values. I want to do something similar with an LSTM, but, as I said, I am new and don’t fully know what I should be doing.
Let’s have a look at the following two LSTMs:
```python
import torch.nn as nn

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x, _ = self.lstm(x)   # (batch, seq_len, hidden_size)
        x = self.linear(x)    # (batch, seq_len, output_size)
        return x
```
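And the second one, which produces num_quantiles values per target:

```python
import torch.nn as nn

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_quantiles, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        # num_quantiles values for each of the output_size targets
        self.linear = nn.Linear(hidden_size, num_quantiles * output_size)

    def forward(self, x):
        x, _ = self.lstm(x)
        quantiles = self.linear(x)   # (batch, seq_len, num_quantiles * output_size)
        return quantiles
```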
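The second model returns one flat last dimension, so it can help to make the grouping into targets and quantile levels explicit. A minimal sketch, assuming the layout is output_size consecutive blocks of num_quantiles values; QuantileLSTMModel is a hypothetical name, and the sizes and input are random dummy values. Note that these numbers only behave like quantiles if the model is trained with a quantile loss.

```python
import torch
import torch.nn as nn

class QuantileLSTMModel(nn.Module):
    """Same layers as the second model, plus a reshape so the last two
    dimensions read as (output_size, num_quantiles)."""
    def __init__(self, input_size, hidden_size, output_size, num_quantiles, num_layers=1):
        super().__init__()
        self.output_size = output_size
        self.num_quantiles = num_quantiles
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, num_quantiles * output_size)

    def forward(self, x):
        x, _ = self.lstm(x)              # (batch, seq_len, hidden_size)
        out = self.linear(x)             # (batch, seq_len, output_size * num_quantiles)
        batch, seq_len, _ = out.shape
        # Assumed layout: one block of num_quantiles values per target
        return out.reshape(batch, seq_len, self.output_size, self.num_quantiles)

torch.manual_seed(0)
model = QuantileLSTMModel(input_size=3, hidden_size=16, output_size=2, num_quantiles=5)
X = torch.randn(4, 10, 3)   # dummy input: (batch, seq_len, input_size)
Y_pred = model(X)
print(Y_pred.shape)  # torch.Size([4, 10, 2, 5])
```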
Does the second model actually predict quantiles, unlike the first model? Or does the first model, in practice, already produce a median prediction?
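The layers alone do not decide whether an output estimates a mean, a median or another quantile; the training loss does. A minimal sketch on made-up, right-skewed data (an exponential distribution, where the mean lies above the median): fitting one constant with MSELoss lands near the mean, while L1Loss lands near the median. The data and the fit_constant helper are illustrative assumptions, not part of the models above.

```python
import torch

torch.manual_seed(0)
# Right-skewed dummy targets: mean is about 1.0, median about 0.69
y = torch.distributions.Exponential(rate=1.0).sample((10_000,))

def fit_constant(loss_fn, steps=2000, lr=0.05):
    """Fit a single scalar c by gradient descent to minimise loss_fn(c, y)."""
    c = torch.zeros(1, requires_grad=True)
    opt = torch.optim.Adam([c], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        loss = loss_fn(c.expand_as(y), y)
        loss.backward()
        opt.step()
    return c.item()

c_mse = fit_constant(torch.nn.MSELoss())
c_l1 = fit_constant(torch.nn.L1Loss())
print(f"MSE fit {c_mse:.3f} vs mean   {y.mean().item():.3f}")
print(f"L1 fit  {c_l1:.3f} vs median {y.median().item():.3f}")
```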
Say I set `num_quantiles = 50`. Is the median then the 25th prediction, as in:
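```python
Y_pred = model(X)  # (batch, seq_len, num_quantiles * output_size)
# Take index num_quantiles // 2 (0-based) from each output's block of num_quantiles values
median_pred = Y_pred[:, :, num_quantiles // 2::num_quantiles]
```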
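Which entry counts as "the median" depends on the quantile levels you train each column with, and the strided slice assumes the flat dimension is laid out as output_size blocks of num_quantiles values. A small sketch on a random stand-in tensor (no trained model involved) checks that the slice picks the same entries as an explicit reshape and, assuming evenly spaced levels (i + 0.5) / num_quantiles (a hypothetical choice), shows which levels sit around index num_quantiles // 2.

```python
import torch

batch, seq_len, output_size, num_quantiles = 4, 10, 3, 50
# Random stand-in for model(X): output_size blocks of num_quantiles values each
Y_pred = torch.randn(batch, seq_len, output_size * num_quantiles)

k = num_quantiles // 2   # 0-based index 25, i.e. the 26th value in each block
strided = Y_pred[:, :, k::num_quantiles]                                       # (4, 10, 3)
reshaped = Y_pred.reshape(batch, seq_len, output_size, num_quantiles)[..., k]  # (4, 10, 3)
print(torch.equal(strided, reshaped))  # True

# Assumed levels: evenly spaced (i + 0.5) / num_quantiles
taus = (torch.arange(num_quantiles) + 0.5) / num_quantiles
print(taus[k - 1].item(), taus[k].item())  # roughly 0.49 and 0.51
```

Under that assumed spacing, 50 levels contain no exact 0.5; an odd count such as 51, or an explicit list of levels that includes 0.5, gives one index that is exactly the median.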
Technically, if I set `num_quantiles = 1` (which, to my layman’s eyes, makes the model similar to the first model), is every prediction then the median? So, is it similar to
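A quick check, using the two class definitions from the question with made-up sizes (input_size=3, hidden_size=16), that with output_size=1 and num_quantiles=1 both networks have exactly the same learnable tensors. The second class is renamed QuantileLSTMModel here only so both fit in one file. Since the architectures match, what the single output ends up estimating (mean, median or another quantile) is decided by the training loss, not by num_quantiles.

```python
import torch.nn as nn

class LSTMModel(nn.Module):
    """First model from the question."""
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x, _ = self.lstm(x)
        return self.linear(x)

class QuantileLSTMModel(nn.Module):
    """Second model from the question, renamed only to avoid a name clash."""
    def __init__(self, input_size, hidden_size, output_size, num_quantiles, num_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, num_quantiles * output_size)

    def forward(self, x):
        x, _ = self.lstm(x)
        return self.linear(x)

first = LSTMModel(input_size=3, hidden_size=16, output_size=1)
second = QuantileLSTMModel(input_size=3, hidden_size=16, output_size=1, num_quantiles=1)

# Compare the shapes of every learnable tensor, in registration order
shapes_first = [tuple(p.shape) for p in first.parameters()]
shapes_second = [tuple(p.shape) for p in second.parameters()]
print(shapes_first == shapes_second)  # True
```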
Then, regarding the loss function: which loss function is advisable in this case? I have tried:
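```python
import torch

def quantile_loss(predicted_values, target_values, quantile):
    # `quantile` is used as an index into the last dimension here,
    # so it must be an integer column index, not a level such as 0.5
    median_pred = predicted_values[:, :, quantile]
    error = median_pred - target_values
    loss = torch.abs(error)
    return loss.mean()
```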
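The loss above weights over- and under-prediction equally, whatever column is selected. The usual loss for quantile regression is the pinball (quantile) loss, which weights errors by tau when the prediction is too low and by 1 - tau when it is too high, so each column is pushed toward its own level. A minimal sketch, assuming a single target (output_size = 1) so the last dimension holds one column per level listed in a taus tensor; pinball_loss and taus are illustrative names, and the tensors are random dummies.

```python
import torch

def pinball_loss(pred, target, taus):
    """Quantile (pinball) loss averaged over all elements and all levels.

    pred:   (batch, seq_len, num_quantiles), column j predicts level taus[j]
    target: (batch, seq_len)
    taus:   1-D tensor of levels in (0, 1) with length num_quantiles
    """
    error = target.unsqueeze(-1) - pred   # > 0 where the prediction is too low
    # Too-low predictions cost tau * error, too-high ones cost (1 - tau) * |error|
    loss = torch.maximum(taus * error, (taus - 1) * error)
    return loss.mean()

taus = torch.tensor([0.1, 0.5, 0.9])
pred = torch.randn(4, 10, len(taus), requires_grad=True)   # dummy predictions
target = torch.randn(4, 10)                                 # dummy targets
loss = pinball_loss(pred, target, taus)
loss.backward()
```

Training all levels jointly with one loss like this is a common choice; nothing in the sketch prevents the predicted quantiles from crossing each other.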
But if I look at the median (quantile = 0.5), what I get seems to be practically the same as nn.L1Loss(). Is that right?
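For tau = 0.5 the pinball loss reduces to half the absolute error, so it has the same minimiser as nn.L1Loss(); the factor 1/2 only rescales the loss and its gradients. A quick numerical check on random dummy tensors; pinball_loss is an illustrative helper, not a PyTorch function.

```python
import torch
import torch.nn as nn

def pinball_loss(pred, target, tau):
    """Pinball loss for a single quantile level tau; pred and target share a shape."""
    error = target - pred
    return torch.maximum(tau * error, (tau - 1) * error).mean()

torch.manual_seed(0)
pred = torch.randn(8, 20)
target = torch.randn(8, 20)

l1 = nn.L1Loss()(pred, target)
p50 = pinball_loss(pred, target, tau=0.5)
print(torch.allclose(p50, 0.5 * l1))  # True
```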