PyTorch - Initializing recurrent matrices to identity/orthogonal matrix

I have this code, which currently initializes the recurrent matrices (Whx and Whh) of my LSTM to the zero matrix. However, I want to initialize them to the identity matrix instead.

# shape: (n_cells, batch_size, d_hidden)
state_shape = self.config.n_cells, batch_size, self.config.d_hidden
h0 = c0 = Variable(inputs.data.new(*state_shape).zero_())
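For comparison, here is a minimal sketch of the same zero-initialized state in current PyTorch, where `Variable` is no longer needed. The sizes are hypothetical placeholders for `self.config`, and it assumes `n_cells` equals `num_layers * num_directions` of an `nn.LSTM`, which is what `nn.LSTM` expects as the first dimension of `h0` and `c0`.

```python
import torch
import torch.nn as nn

# Hypothetical sizes standing in for self.config and the batch.
n_layers, bidirectional = 2, True
n_cells = n_layers * (2 if bidirectional else 1)  # first dim nn.LSTM expects for h0/c0
batch_size, seq_len, d_input, d_hidden = 3, 5, 8, 16

lstm = nn.LSTM(d_input, d_hidden, num_layers=n_layers, bidirectional=bidirectional)
inputs = torch.randn(seq_len, batch_size, d_input)  # (seq_len, batch, features)

# Initial hidden and cell states: zeros with the same dtype/device as inputs.
state_shape = (n_cells, batch_size, d_hidden)
h0 = inputs.new_zeros(state_shape)
c0 = inputs.new_zeros(state_shape)

output, (hn, cn) = lstm(inputs, (h0, c0))
print(output.shape, hn.shape)
```

Note that these tensors are the states fed into the LSTM, not its weights; the weights live in `lstm`'s parameters.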

I am a little confused about how to do this neatly, since the shape is 3D and I cannot use nn.init.eye, which only works on 2D tensors. Also, the first dimension is the number of cells in the sequence rather than the batch size, which in my opinion makes this harder.
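A quick check of the 2D restriction, as a sketch (current PyTorch spells the in-place initializer `nn.init.eye_`): it fills a square or rectangular 2D tensor with the identity and raises `ValueError` for anything else.

```python
import torch
import torch.nn as nn

w = torch.empty(4, 4)
nn.init.eye_(w)  # 2D tensor: filled with the identity in place
print(torch.equal(w, torch.eye(4)))  # True

state = torch.empty(2, 3, 4)  # 3D, like the (n_cells, batch, d_hidden) state
raised = False
try:
    nn.init.eye_(state)  # only 2D tensors are supported
except ValueError as err:
    raised = True
    print("ValueError:", err)
```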

Is there any neat way to do this? Please let me know. Thank you!

Hi

Why would your weights include a third (batch_size?) dimension? Your state, which is what you show, does and should have one, but for the weights that seems unusual.
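To see the difference, here is a small sketch, assuming the model uses `torch.nn.LSTM` rather than a hand-written cell (sizes are hypothetical). In `nn.LSTM`, Whx corresponds to `weight_ih_l{k}` and Whh to `weight_hh_l{k}`; both are plain 2D matrices with the four gates (i, f, g, o) stacked along the first dimension.

```python
import torch.nn as nn

# Hypothetical sizes; substitute your own config values.
d_input, d_hidden, n_layers = 8, 16, 2
lstm = nn.LSTM(d_input, d_hidden, num_layers=n_layers)

# Every weight is a 2D matrix; the four gates (i, f, g, o) are stacked along dim 0.
for name, p in lstm.named_parameters():
    print(name, tuple(p.shape))
```

Here `weight_hh_l0` prints as `(64, 16)`, i.e. four stacked 16 x 16 blocks; nothing in the weights depends on the batch size.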
That said, you can generally assign to submatrices of a parameter’s .data.
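A minimal sketch of that idea, assuming a plain `nn.LSTM` without `proj_size` (so each gate block of `weight_hh_l{k}` is square). `init_recurrent` is a hypothetical helper, not a PyTorch API. It writes into slices of the parameter under `torch.no_grad()`, which is the current equivalent of assigning through `.data`, and sets each gate block to either the identity or a separate orthogonal matrix.

```python
import torch
import torch.nn as nn


def init_recurrent(lstm: nn.LSTM, mode: str = "identity") -> None:
    """Hypothetical helper: set each (h, h) gate block of every weight_hh_l{k}.

    Assumes a plain nn.LSTM (no proj_size), whose weight_hh_l{k} stacks the
    i, f, g, o gate matrices along dim 0, giving shape (4 * h, h).
    """
    if mode not in ("identity", "orthogonal"):
        raise ValueError(f"unknown mode: {mode}")
    h = lstm.hidden_size
    with torch.no_grad():  # modify parameters without recording autograd history
        for name, p in lstm.named_parameters():
            if not name.startswith("weight_hh"):
                continue
            for g in range(4):
                block = p[g * h:(g + 1) * h]  # view of one gate's (h, h) block
                if mode == "identity":
                    block.copy_(torch.eye(h))
                else:
                    nn.init.orthogonal_(block)  # each block gets its own orthogonal matrix


lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2)
init_recurrent(lstm, "identity")
print(torch.equal(lstm.weight_hh_l0[:16], torch.eye(16)))  # True
```

Initializing each gate block separately is a design choice; calling `nn.init.orthogonal_` on the whole `(4 * h, h)` matrix would instead give orthonormal columns across all gates. Setting Whx (`weight_ih_l{k}`) to the identity in the same way only makes sense when its blocks are square, i.e. when the layer's input size equals the hidden size.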

Best regards

Thomas