What is the layout of the GRU weights and bias?

I want to be able to initialize specific parts of a GRUCell's weights and biases in different ways, i.e. differently for the reset gate, the update gate and the candidate. How can I find out the layout of the `weight_ih`, `weight_hh`, `bias_ih` and `bias_hh` tensors?

I think it is worth checking whether the layout described at the end of the GRU documentation also applies to GRUCell (it was probably documented for the RNN/GRU/LSTM modules but not for the cells).

Best regards

Thomas

This bit? Sounds plausible. It would be good to get a link to the source code of the GRUCell implementation (I can get as far as the .py wrapper, but then it zooms off into C++ land, and it's not obvious how to find it in the source tree).

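It seems this is indeed the layout. Experimenting a bit shows that each stacked parameter consists of three `hidden_size`-sized blocks, in the order reset gate r (block 0), update gate z (block 1) and candidate n (block 2):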

import torch
from torch import nn, optim


def run_findz(batch_size=4, input_size=3, hidden_size=4):
    """Find which block of the stacked GRUCell parameters is the update gate z.

    All weights and biases are zeroed and the input is zero, so every gate is
    driven only by the biases. Row ``k`` of the hidden state is filled with
    ``k + 1`` so the rows can be told apart. Each ``hidden_size``-sized block
    of ``bias_hh`` is then saturated (set to 20) in turn. The block for which
    the cell returns its hidden state unchanged is the update gate.

    Args:
        batch_size: number of rows in the input and hidden state.
        input_size: GRUCell input size.
        hidden_size: GRUCell hidden size; also the size of each gate block
            inside the stacked parameters.

    Returns:
        A list with one output tensor per saturated block, in block order.

    Observed output with the default sizes::

        I 3 O 4
        gru.weight_ih.size() torch.Size([12, 3])
        gru.bias_hh.size() torch.Size([12])
        output tensor([[ 0.5000,  0.5000,  0.5000,  0.5000],
                [ 1.0000,  1.0000,  1.0000,  1.0000],
                [ 1.5000,  1.5000,  1.5000,  1.5000],
                [ 2.0000,  2.0000,  2.0000,  2.0000]])
        output tensor([[ 1.,  1.,  1.,  1.],
                [ 2.,  2.,  2.,  2.],
                [ 3.,  3.,  3.,  3.],
                [ 4.,  4.,  4.,  4.]])
        output tensor([[ 1.0000,  1.0000,  1.0000,  1.0000],
                [ 1.5000,  1.5000,  1.5000,  1.5000],
                [ 2.0000,  2.0000,  2.0000,  2.0000],
                [ 2.5000,  2.5000,  2.5000,  2.5000]])

    (so z is block 1)
    """
    print('')
    print('find z')

    print('I', input_size, 'O', hidden_size)
    gru = nn.GRUCell(input_size, hidden_size)
    print('gru.weight_ih.size()', gru.weight_ih.size())
    print('gru.bias_hh.size()', gru.bias_hh.size())

    # The stacked bias holds one hidden_size-sized block per gate. Slicing by
    # hidden_size (not a literal 4) keeps the probe valid for any size.
    num_blocks = gru.bias_hh.numel() // hidden_size

    outputs = []
    # Only parameter surgery and forward passes happen here, so gradients are
    # not needed; no_grad() replaces the legacy ``.data`` access.
    with torch.no_grad():
        for param in (gru.weight_ih, gru.weight_hh, gru.bias_ih, gru.bias_hh):
            param.zero_()

        zero_input = torch.zeros(batch_size, input_size)
        state = torch.zeros(batch_size, hidden_size)
        for row in range(batch_size):
            state[row].fill_(row + 1)

        for block in range(num_blocks):
            gru.bias_hh.zero_()
            gru.bias_hh[block * hidden_size:(block + 1) * hidden_size].fill_(20)
            output = gru(zero_input, state)
            print('output', output)
            outputs.append(output)
    return outputs


def run_findn(batch_size=1, input_size=3, hidden_size=4):
    """Find which block of the stacked GRUCell parameters is the candidate n.

    All weights and biases are zeroed; the input and hidden state are zero.
    Block 1 of ``bias_hh`` (the update gate z, as located by ``run_findz``)
    is pinned to -20. This keeps z close to 0, so the new state is
    essentially the candidate n. Each ``hidden_size``-sized block of
    ``bias_ih`` is then saturated (set to 20) in turn. With zero weights,
    input and state, n can only move away from 0 through its own input
    bias, so the block that makes the output 1 is n.

    Args:
        batch_size: number of rows in the input and hidden state.
        input_size: GRUCell input size.
        hidden_size: GRUCell hidden size; also the size of each gate block
            inside the stacked parameters.

    Returns:
        A list with one output tensor per saturated ``bias_ih`` block, in
        block order.

    Observed output with the default sizes::

        find N
        I 3 O 4
        gru.weight_ih.size() torch.Size([12, 3])
        gru.bias_hh.size() torch.Size([12])
        i 0 output tensor([[ 0.,  0.,  0.,  0.]])
        i 1 output tensor([[ 0.,  0.,  0.,  0.]])
        i 2 output tensor([[ 1.,  1.,  1.,  1.]])

    (so n is block 2)
    """
    print('')
    print('find n')

    print('I', input_size, 'O', hidden_size)
    gru = nn.GRUCell(input_size, hidden_size)
    print('gru.weight_ih.size()', gru.weight_ih.size())
    print('gru.bias_hh.size()', gru.bias_hh.size())

    num_blocks = gru.bias_ih.numel() // hidden_size

    outputs = []
    # No gradients are needed for this probe; no_grad() replaces ``.data``.
    with torch.no_grad():
        for param in (gru.weight_ih, gru.weight_hh, gru.bias_ih, gru.bias_hh):
            param.zero_()

        # Keep the update gate (block 1) shut so the output follows n.
        gru.bias_hh[hidden_size:2 * hidden_size].fill_(-20)

        zero_input = torch.zeros(batch_size, input_size)
        state = torch.zeros(batch_size, hidden_size)

        for block in range(num_blocks):
            gru.bias_ih.zero_()
            gru.bias_ih[block * hidden_size:(block + 1) * hidden_size].fill_(20)
            output = gru(zero_input, state)
            print('i', block, 'output', output)
            outputs.append(output)
    return outputs


def run_findr(batch_size=1, input_size=3, hidden_size=4):
    """Confirm that block 0 of the stacked GRUCell parameters is the reset gate r.

    All weights and biases are zeroed; the input and hidden state are zero.
    In ``bias_hh``:

    * block 1 (update gate z) is pinned to -20, so the new state essentially
      equals the candidate n;
    * block 2 (candidate n) is set to 20.

    Per the nn.GRUCell equations, the recurrent candidate bias is scaled by r.
    So n is about tanh(r * 20), which is about 0 when r is about 0 and about
    1 when r is about 1. Block 0 is then set to -100 (r closed) and to 20
    (r open). The output switching from 0 to 1 shows that block 0 is r.

    Args:
        batch_size: number of rows in the input and hidden state.
        input_size: GRUCell input size.
        hidden_size: GRUCell hidden size; also the size of each gate block
            inside the stacked parameters.

    Returns:
        A two-element list: the output with block 0 at -100, then with
        block 0 at 20.

    Observed output with the default sizes::

        find r
        I 3 O 4
        gru.weight_ih.size() torch.Size([12, 3])
        gru.bias_hh.size() torch.Size([12])
        output tensor([[ 0.,  0.,  0.,  0.]])
        output tensor([[ 1.,  1.,  1.,  1.]])

    (so r is block 0)
    """
    print('')
    print('find r')

    print('I', input_size, 'O', hidden_size)
    gru = nn.GRUCell(input_size, hidden_size)
    print('gru.weight_ih.size()', gru.weight_ih.size())
    print('gru.bias_hh.size()', gru.bias_hh.size())

    r_block = slice(0, hidden_size)
    z_block = slice(hidden_size, 2 * hidden_size)
    n_block = slice(2 * hidden_size, 3 * hidden_size)

    outputs = []
    # No gradients are needed for this probe; no_grad() replaces ``.data``.
    with torch.no_grad():
        for param in (gru.weight_ih, gru.weight_hh, gru.bias_ih, gru.bias_hh):
            param.zero_()

        gru.bias_hh[z_block].fill_(-20)
        gru.bias_hh[n_block].fill_(20)

        zero_input = torch.zeros(batch_size, input_size)
        state = torch.zeros(batch_size, hidden_size)

        # First drive r shut (-100), then wide open (20).
        for r_bias in (-100, 20):
            gru.bias_hh[r_block].fill_(r_bias)
            output = gru(zero_input, state)
            print('output', output)
            outputs.append(output)
    return outputs


if __name__ == '__main__':
    # Run every probe, in parameter-block order (r = 0, z = 1, n = 2), so the
    # whole GRUCell layout is reproduced by a single invocation instead of
    # toggling commented-out calls.
    for probe in (run_findr, run_findz, run_findn):
        probe()

Oh, it only goes to C++ very late; before that, you get a chance to see a pure Python implementation in `torch.nn._functions.GRUCell`. But yes, the `torch.nn._functions` module certainly is one of the more obscure corners of PyTorch. It is bound to vanish, either during the great refactoring to C++ Torch that’s going on or, for the RNN bits, when the great RNN overhaul comes.

Best regards

Thomas