My LSTM model is underfitting

Hello everyone, I’m new to PyTorch and am currently stuck training an LSTM model. I’m trying to reproduce the results of Trading Momentum Transformer (the model is defined in `mom_trans/deep_momentum_network.py`).

Briefly, this work uses an LSTM for a momentum trading strategy. The input is the close price of various tickers, from which `target_returns` can be calculated. The output is a position for each ticker at each time step. The aim is to maximize the Sharpe ratio of the captured returns, i.e. `target_returns` multiplied by the positions.

The LSTM output is followed by a time-distributed, fully connected layer with a `tanh()` activation, a squashing function that directly outputs positions in [-1, 1].

The original repo uses TensorFlow, but I need to migrate to PyTorch for compatibility with my system. Below is my PyTorch implementation:

# We want to maximize the Sharpe ratio, so SharpeLoss returns -sharpe_ratio.
class SharpeLoss(nn.Module):
    """Negative annualized Sharpe ratio of the captured returns.

    All elements of ``y_true * weights`` (every batch entry, time step and
    output) are pooled into a single return series ``r``. The loss is
    ``-sqrt(252) * mean(r) / sqrt(var(r) + 1e-9)``, where ``var`` is the
    population variance ``E[r^2] - E[r]^2``.

    Args:
        output_size: Stored on the instance; not used by ``forward``.
    """

    # sqrt(252): annualization factor, presumably for daily returns with
    # 252 trading days per year.
    _ANNUALIZATION = 252.0 ** 0.5

    def __init__(self, output_size: int = 1):
        super().__init__()
        self.output_size = output_size

    def forward(self, y_true, weights):
        """Return the negative annualized Sharpe ratio as a 0-dim tensor.

        Args:
            y_true: Realized target returns.
            weights: Positions produced by the model, same shape as ``y_true``.

        Raises:
            ValueError: If the two shapes differ. Broadcasting, e.g.
                ``(B, T)`` against ``(B, T, 1)``, would silently build a
                ``(B, T, T)`` product and a meaningless Sharpe ratio.
        """
        if y_true.shape != weights.shape:
            raise ValueError(
                f"SharpeLoss expects y_true and weights of identical shape, "
                f"got {tuple(y_true.shape)} and {tuple(weights.shape)}"
            )
        captured_returns = y_true * weights
        mean_returns = captured_returns.mean()
        # E[r^2] - E[r]^2 can dip slightly below zero through floating-point
        # cancellation; clamp it so sqrt never receives a negative value.
        variance_returns = (
            captured_returns.square().mean() - mean_returns.square()
        ).clamp_min(0.0)
        std_returns = torch.sqrt(variance_returns + 1e-9)
        return -mean_returns * self._ANNUALIZATION / std_returns


# PyTorch has no built-in equivalent of tf.keras.layers.TimeDistributed,
# so we implement one.
class TimeDistributed(nn.Module):
    """Apply ``module`` (then an optional activation) at every time step.

    The two leading axes of a 3-D input are flattened together, so
    ``module`` receives a 2-D ``(N, features)`` tensor. The result is then
    reshaped back to the original two leading axes plus a trailing
    ``output_size`` axis.

    Args:
        module: Module applied to the flattened ``(N, features)`` input.
        output_size: Size of the last axis of ``module``'s output.
        batch_first: True for ``(batch, time, features)`` input, False for
            ``(time, batch, features)``. Both layouts are flattened and
            restored in the same axis order, so the computation is identical.
        activation: Optional callable applied to ``module``'s output;
            ``None`` (the default) means no activation.
    """

    def __init__(self, module, output_size, batch_first=True, activation=None):
        super().__init__()
        self.module = module
        self.output_size = output_size
        self.activation = activation
        self.batch_first = batch_first

    def forward(self, x):
        """Return ``activation(module(x))`` computed independently per step.

        Raises:
            ValueError: If ``x`` is not 3-D.
        """
        if x.dim() != 3:
            raise ValueError(
                f"TimeDistributed expects a 3-D input, got shape {tuple(x.shape)}"
            )
        leading = x.shape[:2]  # (batch, time) or (time, batch)
        output = self.module(x.reshape(-1, x.size(-1)))
        # activation is optional (default None), so only call it when set.
        if self.activation is not None:
            output = self.activation(output)
        return output.reshape(*leading, self.output_size)

class LstmDeepMomentumNetworkModel(nn.Module):
    """LSTM -> dropout -> per-time-step linear layer -> tanh.

    ``forward`` takes ``x`` laid out as (batch, time_steps, input_size); the
    LSTM is built with ``batch_first=True``. It returns one value per time
    step and output, squashed into [-1, 1] by tanh.

    Args:
        hidden_layer_size: Number of LSTM hidden units.
        dropout_rate: Dropout probability applied to the LSTM outputs.
        **params: Required keys ``"total_time_steps"``, ``"input_size"``,
            ``"output_size"``, ``"evaluate_diversified_val_sharpe"`` and
            ``"force_output_sharpe_length"``. The last two are stored as
            attributes but not used inside this class.

    NOTE(review): if this must match a Keras reference exactly, check
    whether that reference also uses LSTM input dropout or a kernel
    constraint on the dense layer; neither is implemented here.
    """

    def __init__(self, hidden_layer_size, dropout_rate, **params):
        super().__init__()
        settings = dict(params)

        self.time_steps = int(settings["total_time_steps"])
        self.input_size = int(settings["input_size"])
        self.output_size = int(settings["output_size"])
        self.hidden_layer_size = hidden_layer_size
        self.dropout_rate = dropout_rate
        self.evaluate_diversified_val_sharpe = settings["evaluate_diversified_val_sharpe"]
        self.force_output_sharpe_length = settings["force_output_sharpe_length"]

        # Keep this construction order. Module construction consumes the
        # global RNG before _reinitialize runs, and registration order
        # fixes the state_dict key order.
        self.lstm = nn.LSTM(
            input_size=self.input_size,
            hidden_size=self.hidden_layer_size,
            bias=True,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hidden_layer_size, self.output_size)
        # self.fc is shared with the wrapper below. named_parameters() reports
        # its tensors once, under the "fc." prefix (registered first).
        self.time_distributed = TimeDistributed(
            module=self.fc,
            output_size=self.output_size,
            activation=nn.Tanh(),
            batch_first=True,
        )
        self._reinitialize()

    def _reinitialize(self):
        """Re-initialize parameters in a TensorFlow/Keras-like way.

        LSTM: Glorot-uniform input weights, orthogonal recurrent weights,
        and zero biases, except that the second quarter of ``bias_ih`` is set
        to 1. That quarter is the forget gate in PyTorch's (i, f, g, o) gate
        order. Linear layer: Glorot-uniform weight and zero bias.
        """

        def zeros_with_forget_gate_one(tensor):
            nn.init.zeros_(tensor)
            rows_per_gate = tensor.size(0) // 4
            tensor[rows_per_gate:2 * rows_per_gate].fill_(1)

        # (substring, initializer) pairs; the first matching substring wins.
        lstm_rules = (
            ("weight_ih", nn.init.xavier_uniform_),
            ("weight_hh", nn.init.orthogonal_),
            ("bias_ih", zeros_with_forget_gate_one),
            ("bias_hh", nn.init.zeros_),
        )
        fc_rules = (
            ("weight", nn.init.xavier_uniform_),
            ("bias", nn.init.zeros_),
        )

        for pname, param in self.named_parameters():
            if "lstm" in pname:
                rules = lstm_rules
            elif "fc" in pname:
                rules = fc_rules
            else:
                continue
            for key, init_fn in rules:
                if key in pname:
                    init_fn(param.data)
                    break

    def forward(self, x):
        """Map x (batch, time_steps, input_size) to (batch, time_steps, output_size)."""
        sequence, _final_state = self.lstm(x)
        return self.time_distributed(self.dropout(sequence))

# Training function
def train_n_epochs(device, n_epochs, model, train_loader, valid_loader, criterion,
                   optimizer, early_stop_epoch, max_grad_norm, *,
                   clip_per_parameter=False):
    """Train ``model`` with gradient clipping and early stopping.

    Each epoch runs one optimization pass over ``train_loader``, then one
    evaluation pass (under ``torch.no_grad``) over ``valid_loader``. The
    criterion is called as ``criterion(target, output)``. An epoch's loss is
    the mean of its batch losses weighted by batch size. This equals the
    per-sample average when the criterion is a mean over the batch.

    Args:
        device: Device the model and every batch are moved to.
        n_epochs: Maximum number of epochs.
        model: Module to train (moved to ``device`` in place).
        train_loader: Iterable of ``(data, target)`` tensor pairs whose first
            axis is the batch axis.
        valid_loader: Same as ``train_loader``, used for validation.
        criterion: Callable ``(target, output) -> scalar tensor``.
        optimizer: Optimizer over ``model``'s parameters.
        early_stop_epoch: Stop after this many consecutive epochs whose
            validation loss is not ``<=`` the best seen so far.
        max_grad_norm: Gradient-clipping threshold.
        clip_per_parameter: If False (default), clip the global norm of all
            gradients together. If True, clip each parameter's gradient
            separately, like Keras' ``clipnorm`` optimizer option.

    Returns:
        ``(train_loss_set, valid_loss_set, valid_loss_min, best_model)``: the
        per-epoch losses, the lowest validation loss, and a deep copy of the
        model from that epoch (``None`` if no epoch's validation loss was
        ``<=`` infinity, e.g. every validation loss was NaN).
    """
    valid_loss_min = float("inf")  # np.Inf no longer exists in NumPy >= 2.0
    train_loss_set = []
    valid_loss_set = []
    invariant_epochs = 0
    best_model = None

    model.to(device)
    print(f'Using {device} for training')

    for epoch_i in range(n_epochs):
        train_loss = _train_one_epoch(device, model, train_loader, criterion,
                                      optimizer, max_grad_norm, clip_per_parameter)
        valid_loss = _evaluate_one_epoch(device, model, valid_loader, criterion)

        train_loss_set.append(train_loss)
        valid_loss_set.append(valid_loss)

        print(f'Epoch: {epoch_i + 1} Training Loss: {train_loss:.6f} Validation Loss: {valid_loss:.6f}')

        # If the validation loss did not get worse, keep a snapshot of the model.
        if valid_loss <= valid_loss_min:
            print(f'Validation loss decreased ({valid_loss_min:.6f} --> {valid_loss:.6f}). Saving model...')
            valid_loss_min = valid_loss
            invariant_epochs = 0
            best_model = copy.deepcopy(model)
        else:
            invariant_epochs += 1

        if invariant_epochs >= early_stop_epoch:
            print(f"Early Stop at Epoch [{epoch_i + 1}]: Performance hasn't improved for {early_stop_epoch} epochs")
            break

    return train_loss_set, valid_loss_set, valid_loss_min, best_model


def _train_one_epoch(device, model, loader, criterion, optimizer, max_grad_norm,
                     clip_per_parameter):
    """Run one optimization pass; return the batch-size-weighted mean loss."""
    model.train()
    loss_sum, n_samples = 0.0, 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(target, model(data))
        loss.backward()
        if clip_per_parameter:
            for param in model.parameters():
                if param.grad is not None:
                    torch.nn.utils.clip_grad_norm_(param, max_grad_norm)
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        # Each batch loss is already a batch-level mean. Dividing a plain sum
        # of them by the number of samples (rather than batches) would shrink
        # the epoch value by roughly the batch size, so weight by batch size.
        batch_size = data.size(0)
        loss_sum += loss.item() * batch_size
        n_samples += batch_size
    # A loader that yields no batches reports NaN instead of dividing by zero.
    return loss_sum / n_samples if n_samples else float("nan")


@torch.no_grad()
def _evaluate_one_epoch(device, model, loader, criterion):
    """Run one evaluation pass; return the batch-size-weighted mean loss."""
    model.eval()
    loss_sum, n_samples = 0.0, 0
    for data, target in loader:
        data, target = data.to(device), target.to(device)
        loss = criterion(target, model(data))
        batch_size = data.size(0)
        loss_sum += loss.item() * batch_size
        n_samples += batch_size
    return loss_sum / n_samples if n_samples else float("nan")

Here is how I train it:

# SharpeLoss returns the negative Sharpe ratio, so minimizing it maximizes Sharpe.
criterion = SharpeLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Train on CPU for up to 300 epochs, clipping gradients to norm 1.0 and
# stopping once validation loss has not improved for 30 consecutive epochs.
_, _, min_val_loss, best_model = train_n_epochs(
    'cpu',
    n_epochs=300,
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    criterion=criterion,
    optimizer=optimizer,
    early_stop_epoch=30,
    max_grad_norm=1.0,
)

The data preparation is the same as in the original repo.

The problem is that the training loss stays high. I expect it to decrease to around -1.x to -2.x (as the original TensorFlow code reports), but my training loss only reaches about -0.5 and then stops decreasing. The validation loss is even higher. I have tried different hyperparameter sets but still cannot find good ones, so I suspect something is wrong with my model or loss function…