Implementing a stacked RNN (LSTM/GRU) with short/skip connections in PyTorch

I’m implementing a stacked LSTM/GRU model with skip connections in PyTorch. PyTorch’s built-in nn.LSTM/nn.GRU modules don’t support skip connections between stacked layers, so I can’t use the num_layers option. Instead, I create separate single-layer LSTMs and stack them, concatenating the initial input to the output of every LSTM layer except the last one. The code runs, but I’d like someone to confirm that it behaves like a stacked LSTM/GRU model.

class Model(nn.Module):
    """Stacked LSTM/GRU network with optional skip ("short") connections to the input.

    Layers are added one at a time with ``addRNNLayer`` and the model is finalized
    with ``compile``, which creates the linear output head. When ``shortConnections``
    is enabled, the original input sequence is concatenated (along the feature axis)
    to the output of every RNN layer except the last one before it is fed to the
    next layer.

    Args:
        type: ``torch.nn.LSTM.__name__`` ("LSTM") or ``torch.nn.GRU.__name__`` ("GRU").
        input: Number of features per timestep of the input sequence.
        aggregation: How the last layer's per-timestep outputs are reduced to a single
            vector: ``Model.LAST``, ``Model.MEAN`` or ``Model.MAX``.
        output: Number of output features of the final linear layer.
        shortConnections: Whether to concatenate the original input to the output of
            every RNN layer except the last.
        seed: Optional seed forwarded to ``enableDeterminism``.

    Raises:
        ValueError: If ``aggregation`` is not a supported method.
    """

    LAST = "LAST"
    MEAN = "MEAN"
    MAX = "MAX"
    _AGGREGATIONS = (LAST, MEAN, MAX)

    def __init__(self, type, input, aggregation, output, shortConnections, seed=None):
        super().__init__()
        # Fail early instead of on the first forward pass.
        if aggregation not in Model._AGGREGATIONS:
            raise ValueError("Invalid aggregation method")
        if seed is not None:
            Model.enableDeterminism(seed)

        # LSTM or GRU (compared against torch.nn.LSTM.__name__ / torch.nn.GRU.__name__)
        self.__type = type

        # Feature dimension of each input timestep
        self.__input = input

        # Aggregation method of the outputs of the final layer
        self.__aggregation = aggregation

        # Input size the *next* RNN layer must accept (-1 until the first layer is added)
        self.__previousOutput = -1

        # Linear output head, created by compile()
        self.__output = output
        self.__linear = None

        # Misc
        self.__compiled = False
        self.__shortConnections = shortConnections
        # nn.ModuleList (not a plain list) registers the layers as submodules, so they
        # appear in parameters()/state_dict() and follow .to(device)/.train()/.eval().
        self.__rnnLayers = nn.ModuleList()
        self.__shouldEvaluate = False

    def forward(self, inputs, verbose=False):
        """Run the stacked RNN, aggregate over time and apply the linear head.

        Args:
            inputs: Batch-first input tensor. It is presumably shaped
                (batch, timesteps, input), since every layer uses ``batch_first=True``
                and aggregation reduces dim 1. TODO confirm for unbatched inputs.
            verbose: If True, print intermediate tensors for debugging.

        Returns:
            The linear layer applied to the aggregated outputs of the last RNN layer.

        Raises:
            RuntimeError: If called before ``compile``.
            ValueError: If the aggregation method is invalid.
        """
        if not self.__compiled:
            raise RuntimeError("Model must be compiled before calling forward")

        x = inputs
        lastIndex = len(self.__rnnLayers) - 1

        for i, layer in enumerate(self.__rnnLayers):
            if verbose:
                print(f"\nRNN Layer {i + 1}:\n")
                print(f"Inputs:\n{x}")

            # Only the per-timestep outputs are used; the final hidden (and cell) state is discarded.
            outputs, _ = layer(x)
            if verbose:
                print(f"Outputs:\n{outputs}")

            # Concatenate the initial inputs to every output besides the one of the last layer,
            # but only when skip connections are enabled (addRNNLayer sized the layers accordingly).
            if self.__shortConnections and i < lastIndex:
                x = torch.cat([inputs, outputs], dim=-1)
            else:
                x = outputs

        if verbose:
            print(f"\nUnprocessed Inputs:\n{x}")

        if self.__aggregation == Model.MAX:
            # Element-wise max over the timestep dimension.
            x, _ = torch.max(x, dim=1)
        elif self.__aggregation == Model.MEAN:
            # Element-wise mean over the timestep dimension.
            x = torch.mean(x, dim=1)
        elif self.__aggregation == Model.LAST:
            # Output at the last timestep.
            # NOTE(review): for a bidirectional last layer, the backward half of this vector has only
            # seen the last timestep; consider combining it with x[:, 0, hidden:] if that matters.
            x = x[:, -1, :]
        else:
            raise ValueError("Invalid aggregation method")

        if verbose:
            print(f"Inputs for linear with shape {x.shape}:\n{x}")

        return self.__linear(x)

    def addRNNLayer(self, hiddenSize, bidirectional):
        """Append a single-layer LSTM/GRU on top of the current stack.

        Args:
            hiddenSize: Hidden size of the new layer.
            bidirectional: Whether the new layer is bidirectional (doubles its output size).

        Raises:
            RuntimeError: If the model has already been compiled.
            ValueError: If the model's RNN type is neither LSTM nor GRU.
        """
        if self.__compiled:
            raise RuntimeError("Cannot add RNN layers after the model has been compiled")

        # Input size is the output of the previous layer, or the timestep dimension of the initial sequences.
        inputSize = self.__input if self.__previousOutput == -1 else self.__previousOutput

        rnnTypes = {
            torch.nn.LSTM.__name__: torch.nn.LSTM,
            torch.nn.GRU.__name__: torch.nn.GRU,
        }
        rnnClass = rnnTypes.get(self.__type)
        if rnnClass is None:
            raise ValueError("Invalid RNN type")

        self.__rnnLayers.append(
            rnnClass(input_size=inputSize, hidden_size=hiddenSize, bidirectional=bidirectional, num_layers=1, batch_first=True)
        )

        # A bidirectional layer concatenates both directions' outputs.
        self.__previousOutput = 2 * hiddenSize if bidirectional else hiddenSize

        # With skip connections the next layer also receives the original input features.
        if self.__shortConnections:
            self.__previousOutput += self.__input

    def compile(self):
        """Create the linear output head and freeze the layer stack.

        Raises:
            RuntimeError: If no RNN layer has been added.
        """
        if not self.__rnnLayers:
            raise RuntimeError("Add at least one RNN layer before compiling")

        # forward() never concatenates the inputs after the last layer, so the head must accept only
        # the last layer's own output size (not __previousOutput, which includes the skip features).
        lastLayer = self.__rnnLayers[-1]
        lastOutput = lastLayer.hidden_size * (2 if lastLayer.bidirectional else 1)

        self.__linear = torch.nn.Linear(lastOutput, self.__output)
        self.__compiled = True

    def getRNNLayers(self):
        """Return the registered RNN layers (an ``nn.ModuleList``), in stacking order."""
        return self.__rnnLayers

    @staticmethod
    def saveModel(stateDict, path):
        """Save ``stateDict`` to ``<path>/Model.pth`` using the platform's path separator."""
        import os

        torch.save(stateDict, os.path.join(path, "Model.pth"))

    @staticmethod
    def enableDeterminism(seed):
        """Seed the torch, numpy and Python ``random`` generators with ``seed``."""
        torch.manual_seed(seed)
        numpy.random.seed(seed)
        random.seed(seed)


# Dummy batch with a single sequence of 2 timesteps, each having 3 dimensions
# (batch-first layout: batch=1, timesteps=2, features=3).
instance = torch.tensor([[[1, 2, 3], [4, 5, 6]]], dtype=torch.float32)

model = Model(type=torch.nn.LSTM.__name__, input=3, output=3, shortConnections=True, seed=0, aggregation=Model.LAST)
for _ in range(3):
    model.addRNNLayer(hiddenSize=4, bidirectional=False)
model.compile()

for layer in model.getRNNLayers():
    print(layer)

# Call the module itself rather than .forward() so PyTorch's hooks run;
# this demo only runs inference, so no autograd graph is needed.
with torch.no_grad():
    prediction = model(instance)
print(f"Prediction:\n{prediction}")

Let me walk through the code with an example to make my reasoning clear. Suppose we have 2 LSTM layers, each with a hidden size of 4. The input sequence is [[1, 2, 3], [4, 5, 6]], consisting of 2 timesteps with 3 dimensions each. In Layer 1, each timestep is processed to produce a hidden state, giving h1 = [a, b, c, d] and h2 = [e, f, g, h], one per timestep. The initial input is then concatenated to these hidden states, resulting in h1 = [1, 2, 3, a, b, c, d] and h2 = [4, 5, 6, e, f, g, h]. These form the sequence [h1, h2], where each timestep is now 7-dimensional (4 hidden + 3 input features), which is why Layer 2 is created with an input size of 7. In Layer 2, these timesteps are processed to generate two new hidden states, h3 = [i, j, k, l] and h4 = [m, n, o, p]. The initial input is not concatenated to these hidden states because Layer 2 is the last LSTM layer. The final hidden states for each timestep are then aggregated using my custom method, which is not the focus here. The code is just a rough draft, written on the assumption that my understanding of how stacked LSTM/GRU layers work is correct.