Learn just 1 hidden vector for RNN

When learning the initial hidden vector of an RNN (e.g. an LSTM), it is common practice to do the following:

self.hidden_params = torch.nn.Parameter(
    torch.randn(self.layers * directions, batch_size, self.hidden_size,
                requires_grad=True))

self.cell_params = torch.nn.Parameter(
    torch.randn(self.layers * directions, batch_size, self.hidden_size,
                requires_grad=True))

and pass them to the RNN (e.g. torch.nn.LSTM) like this:

self.rnn(inp, (self.hidden_params, self.cell_params))

But this way, a different hidden vector is learned for each batch position, so instead of learning one hidden vector we actually learn batch_size of them. This means learning more parameters than needed, and it makes the network order-dependent: the same data gives different outputs when the samples are ordered differently within the batch. How can you learn just one hidden vector? I tried expanding a single vector, but it does not seem to work:

self.hidden_params = torch.nn.Parameter(
                torch.randn(self.layers * directions, 1, self.hidden_size, requires_grad=True)
                    .expand(-1, batch_size, -1)
            )
self.cell_params = torch.nn.Parameter(
                torch.randn(self.layers * directions, 1, self.hidden_size, requires_grad=True)
                    .expand(-1, batch_size, -1))

After training, here are the weights for self.hidden_params[:,0,:]

[[ 0.4211,  1.0297, -0.2927,  ...,  0.0929, -0.4851,  1.4456],
        [ 0.3757,  0.0579,  0.0987,  ..., -0.4427,  0.7364,  0.9339],
        [-1.1260,  0.1967, -0.8181,  ..., -0.0739, -1.4523, -0.7049],
        [ 1.6345, -0.5219,  0.8338,  ...,  0.8915, -0.5324, -1.3342]]

and here for self.hidden_params[:,1,:]

[[ 0.4042,  1.0263, -0.3423,  ...,  0.1028, -0.4955,  1.4253],
        [ 0.3941,  0.0707,  0.1011,  ..., -0.4712,  0.6899,  0.9811],
        [-1.1774,  0.1880, -0.7906,  ..., -0.0747, -1.4561, -0.7100],
        [ 1.6640, -0.5889,  0.8723,  ...,  0.9609, -0.5842, -1.2550]]

Similar, but not the same.

Can you try this?

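def __init__(self):
    # Batch dimension is 1, so only one vector per layer/direction is a parameter.
    self.hidden_params = torch.nn.Parameter(torch.randn(self.layers * directions, 1, self.hidden_size))
    self.cell_params = torch.nn.Parameter(torch.randn(self.layers * directions, 1, self.hidden_size))
    ....

def forward(self, inp):
    # Assumes the default batch_first=False layout: inp is (seq_len, batch, features).
    batch_size = inp.size(1)
    hidden_params = self.hidden_params.repeat(1, batch_size, 1)
    cell_params = self.cell_params.repeat(1, batch_size, 1)
    # Pass the repeated tensors, not the raw (batch size 1) parameters.
    ... = self.rnn(inp, (hidden_params, cell_params))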

Basically, keep a single vector as the parameter and repeat it along the batch dimension in the forward call.
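For reference, here is a minimal self-contained sketch of that idea. The class name LearnedInitLSTM and its constructor arguments are made up for illustration and are not from the code above, and it assumes the default batch_first=False layout. The last lines check that permuting the samples in a batch only permutes the outputs, which was the order-dependence problem raised in the question.

```python
import torch


class LearnedInitLSTM(torch.nn.Module):
    """Hypothetical wrapper: an LSTM with one learned initial (h0, c0) pair."""

    def __init__(self, input_size, hidden_size, layers=1, bidirectional=False):
        super().__init__()
        directions = 2 if bidirectional else 1
        self.rnn = torch.nn.LSTM(input_size, hidden_size, num_layers=layers,
                                 bidirectional=bidirectional)
        # Batch dimension is 1: one vector per layer and direction.
        self.hidden_params = torch.nn.Parameter(
            torch.randn(layers * directions, 1, hidden_size))
        self.cell_params = torch.nn.Parameter(
            torch.randn(layers * directions, 1, hidden_size))

    def forward(self, inp):
        # inp: (seq_len, batch, input_size), i.e. batch_first=False.
        batch_size = inp.size(1)
        # Copy the single learned vector to every batch position.
        h0 = self.hidden_params.repeat(1, batch_size, 1)
        c0 = self.cell_params.repeat(1, batch_size, 1)
        return self.rnn(inp, (h0, c0))


model = LearnedInitLSTM(input_size=8, hidden_size=16, layers=2, bidirectional=True)
inp = torch.randn(5, 3, 8)            # seq_len=5, batch=3
out, (h_n, c_n) = model(inp)          # out has shape (5, 3, 32)

# Reordering the batch only reorders the outputs.
perm = torch.tensor([2, 0, 1])
out_perm, _ = model(inp[:, perm])
same_up_to_order = torch.allclose(out[:, perm], out_perm, atol=1e-5)
```

Since the parameters have shape (layers * directions, 1, hidden_size), their size no longer depends on the batch size, so the same module should also accept batches of a different size at inference time.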

Yeah, that works, thanks!