Predicting median with LSTM

I’m new to deep learning and LSTMs, and I’m running into some trouble. My target values are often skewed, and I want an LSTM that predicts extreme values more accurately. Please excuse my inexperience.

In R (using RStudio), the ranger() function from the ranger package can predict quantiles (a distribution) rather than a single point prediction. I take the median of those quantiles, and this gives me better predictions for extreme values. I want to do something similar with an LSTM, but, as I said, I am new to this and don’t fully know what I should be doing.

Consider the following two models: a plain LSTM, and a version with num_quantiles outputs per target:

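import torch
import torch.nn as nn

class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()  # required before assigning submodules
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x, _ = self.lstm(x)   # (batch, seq_len, hidden_size)
        x = self.linear(x)    # (batch, seq_len, output_size)
        return x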


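class QuantileLSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_quantiles, num_layers=1):
        super().__init__()  # required before assigning submodules
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        # one output per (target, quantile) pair
        self.linear = nn.Linear(hidden_size, num_quantiles * output_size)

    def forward(self, x):
        x, _ = self.lstm(x)          # (batch, seq_len, hidden_size)
        quantiles = self.linear(x)   # (batch, seq_len, num_quantiles * output_size)
        return quantiles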

Does the second model actually predict quantiles, unlike the first? Or does the first model effectively already estimate the median?
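Whether a single output behaves like a mean or a median is decided by the loss it is trained with, not by the architecture. Here is a minimal sketch on a made-up, right-skewed sample: scanning a grid of constant predictions, the constant that minimises squared error is the sample mean, while the one that minimises absolute error is the sample median.

```python
import torch

# Toy right-skewed sample (illustrative values only); float64 avoids rounding ties
y = torch.tensor([1.0, 2.0, 3.0, 4.0, 20.0], dtype=torch.float64)

# Candidate constant predictions on a fine grid (step 0.001)
candidates = torch.linspace(0.0, 25.0, 25001, dtype=torch.float64)

# Loss of each candidate over the whole sample: shape (25001,)
mse = ((candidates[:, None] - y[None, :]) ** 2).mean(dim=1)
mae = (candidates[:, None] - y[None, :]).abs().mean(dim=1)

best_mse = candidates[mse.argmin()].item()
best_mae = candidates[mae.argmin()].item()

print(f"MSE minimiser: {best_mse:.3f}  (mean = {y.mean().item():.3f})")
print(f"MAE minimiser: {best_mae:.3f}  (median = {y.median().item():.3f})")
```

The outlier at 20 pulls the mean up to 6, while the median stays at 3. By the same logic, the first model trained with nn.MSELoss() aims at the conditional mean, and trained with nn.L1Loss() it aims at the conditional median (assuming enough capacity and data).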

Say I set num_quantiles = 50. Is the median then the prediction at index 25 (num_quantiles // 2), as in:

Y_pred = model(X)  # (batch, seq_len, num_quantiles * output_size)
median_pred = Y_pred[:, :, num_quantiles // 2::num_quantiles]
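Which index holds the median depends on two choices you make yourself, because nn.Linear imposes neither: the quantile levels used in the loss, and how the num_quantiles * output_size outputs are grouped. The sketch below assumes an output-major layout (each target’s quantiles are contiguous) and hypothetical midpoint levels. With 50 such levels none is exactly 0.5, index 25 (the 26th level when counting from 1) is about 0.51, and the strided slice from the question is equivalent to a reshape.

```python
import torch

batch, seq_len, output_size, num_quantiles = 4, 10, 2, 50

# Hypothetical quantile levels: midpoints 0.01, 0.03, ..., 0.99 (none equals 0.5)
taus = (torch.arange(num_quantiles) + 0.5) / num_quantiles
print(taus[num_quantiles // 2 - 1].item(), taus[num_quantiles // 2].item())  # about 0.49 and 0.51

# Random stand-in for model(X): (batch, seq_len, output_size * num_quantiles)
y_pred = torch.randn(batch, seq_len, output_size * num_quantiles)

# Assumed output-major layout: flat index = o * num_quantiles + k
q = y_pred.reshape(batch, seq_len, output_size, num_quantiles)

# The strided slice picks quantile k = num_quantiles // 2 for every target
k = num_quantiles // 2
median_slice = y_pred[:, :, k::num_quantiles]
assert torch.equal(median_slice, q[..., k])

# With an even number of levels, average the two levels straddling 0.5
median_approx = 0.5 * (q[..., k - 1] + q[..., k])

# With an odd number of levels, 0.5 can be one of them exactly (index 24 here)
taus_odd = torch.linspace(0.02, 0.98, 49)
print(taus_odd[49 // 2].item())  # about 0.5
```

If you need an exact median output, choosing an odd number of levels that includes 0.5 (as with taus_odd) avoids the averaging step.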

Technically, if I set num_quantiles = 1 (which, to my layman’s eyes, makes it the same as the first model), is every prediction then the median? In other words, is it equivalent to median_pred above?
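With num_quantiles = 1 the second model produces the same output shape as the first, and the strided slice reduces to the full output (Y_pred[:, :, 0::1]). What that single output estimates is again set by the loss. Here is a sketch, assuming the standard pinball (quantile) loss and a toy sample 1, 2, ..., 100: the best constant under level tau lands at roughly the tau-quantile of the data, so only tau = 0.5 gives the median. The helper name pinball is hypothetical.

```python
import torch

def pinball(pred, target, tau):
    # Hypothetical helper, not a library function.
    # Pinball loss with e = target - pred: tau * e if e >= 0, else (tau - 1) * e
    e = target - pred
    return torch.maximum(tau * e, (tau - 1) * e).mean(dim=-1)

# Toy sample 1..100 (illustrative only)
y = torch.arange(1, 101, dtype=torch.float64)
candidates = torch.linspace(0.0, 101.0, 10101, dtype=torch.float64)  # step 0.01

results = {}
for tau in (0.5, 0.9):
    losses = pinball(candidates[:, None], y[None, :], tau)  # (10101,)
    results[tau] = candidates[losses.argmin()].item()
    print(f"tau={tau}: best constant = {results[tau]:.2f}")
```

For this sample any constant between 50 and 51 is optimal at tau = 0.5, and any constant between 90 and 91 is optimal at tau = 0.9. The single output is therefore "the median" only if you train it with the 0.5 level (or, equivalently, with L1).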

Then there is the loss function. Which loss function is advisable in this case? I have tried:

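def quantile_loss(predicted_values, target_values, quantile_index):
    # select one output column (an integer index): (batch, seq_len)
    median_pred = predicted_values[:, :, quantile_index]
    error = median_pred - target_values  # target_values: (batch, seq_len)
    loss = torch.abs(error)              # plain absolute error, identical for every column
    return loss.mean()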

But when I look at the median (the 0.5 quantile), isn’t this practically the same as nn.L1Loss()?
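A common choice for predicting several quantiles at once is the pinball (quantile) loss, averaged over all levels. Here is a minimal sketch, assuming the output-major layout and the hypothetical helper name pinball_loss (not a library function). At the single level 0.5 it is exactly half of nn.L1Loss(), so both have the same minimiser (the median), which is consistent with what you observed.

```python
import torch
import torch.nn as nn

def pinball_loss(y_pred, y_true, quantile_levels):
    """Hypothetical helper: mean pinball loss over all quantile levels.

    y_pred: (batch, seq_len, output_size * num_quantiles), assumed output-major
    y_true: (batch, seq_len, output_size)
    quantile_levels: 1-D tensor of levels in (0, 1), length num_quantiles
    """
    num_quantiles = quantile_levels.numel()
    batch, seq_len, _ = y_pred.shape
    q = y_pred.reshape(batch, seq_len, -1, num_quantiles)  # (batch, seq, output, quantile)
    e = y_true.unsqueeze(-1) - q                            # broadcast target over quantiles
    taus = quantile_levels.to(y_pred)                       # match dtype and device
    # under-prediction (e > 0) weighted by tau, over-prediction by 1 - tau
    return torch.maximum(taus * e, (taus - 1) * e).mean()

# Usage with random tensors standing in for model(X) and the targets
torch.manual_seed(0)
levels = torch.tensor([0.1, 0.5, 0.9])
y_pred = torch.randn(8, 20, 2 * len(levels))
y_true = torch.randn(8, 20, 2)
loss = pinball_loss(y_pred, y_true, levels)

# With the single level 0.5, the pinball loss is exactly half of L1
pred_median = torch.randn(8, 20, 2)
half_l1 = 0.5 * nn.L1Loss()(pred_median, y_true)
pinball_05 = pinball_loss(pred_median, y_true, torch.tensor([0.5]))
assert torch.allclose(pinball_05, half_l1)
```

What distinguishes it from torch.abs(error) is the asymmetric weighting (tau for under-prediction, 1 - tau for over-prediction). That asymmetry is what drives each output column toward a different quantile.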