How does PyTorch pack LSTM weights to pass to cuDNN?

I have trained a model that uses nn.LSTMCell. For throughput reasons, I want to call cuDNN’s cudnnRNNForwardInference directly, which means I have to export the weights of nn.LSTMCell. For other layers such as linear or convolution this wasn’t hard, but it is difficult for nn.LSTMCell: nn.LSTMCell exposes two weight matrices and two bias vectors, while cudnnRNNForwardInference takes a single packed weight buffer (w, described by wDesc):

nn.LSTMCell

LSTMCell.weight_ih – the learnable input-hidden weights, of shape (4*hidden_size, input_size)

LSTMCell.weight_hh – the learnable hidden-hidden weights, of shape (4*hidden_size, hidden_size)

LSTMCell.bias_ih – the learnable input-hidden bias, of shape (4*hidden_size)

LSTMCell.bias_hh – the learnable hidden-hidden bias, of shape (4*hidden_size)
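Per the PyTorch documentation, the four gates are stacked along dim 0 in the order i, f, g, o, so each gate’s slice can be recovered directly. A quick illustration (toy sizes assumed):

```python
import torch

cell = torch.nn.LSTMCell(input_size=8, hidden_size=16)  # toy sizes

# weight_ih has shape (4*hidden_size, input_size); the four gate
# matrices are stacked along dim 0 in the order i, f, g, o
W_ii, W_if, W_ig, W_io = cell.weight_ih.chunk(4, dim=0)
assert W_ii.shape == (16, 8)

# the biases split the same way
b_ii, b_if, b_ig, b_io = cell.bias_ih.chunk(4, dim=0)
assert b_ii.shape == (16,)
```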

cudnnRNNForwardInference

cudnnStatus_t cudnnRNNForwardInference(
    cudnnHandle_t                   handle,
    const cudnnRNNDescriptor_t      rnnDesc,
    const int                       seqLength,
    const cudnnTensorDescriptor_t  *xDesc,
    const void                     *x,
    const cudnnTensorDescriptor_t   hxDesc,
    const void                     *hx,
    const cudnnTensorDescriptor_t   cxDesc,
    const void                     *cx,
    const cudnnFilterDescriptor_t   wDesc,
    const void                     *w,
    const cudnnTensorDescriptor_t   *yDesc,
    void                           *y,
    const cudnnTensorDescriptor_t   hyDesc,
    void                           *hy,
    const cudnnTensorDescriptor_t   cyDesc,
    void                           *cy,
    void                           *workspace,
    size_t                          workSpaceSizeInBytes)

How would I pack the weights weight_ih, weight_hh, bias_ih, and bias_hh so that they can be passed to cuDNN’s cudnnRNNForwardInference as the w buffer described by wDesc?

See Parameters In Tensorflow Keras RNN and CUDNN RNN on Kaixi Hou’s Log, a great write-up on this topic.

The function below packs the weights in the order used by cuDNN, following the article referenced above and the PyTorch documentation (LSTM — PyTorch 2.0 documentation). Note that PyTorch’s gate ordering differs from the Keras ordering used in that article, so the wtig and wtio slices are swapped here accordingly.
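The packing loop relies on state_dict() yielding the parameters layer by layer in a fixed order (weight_ih, weight_hh, bias_ih, bias_hh per layer), which can be verified quickly (toy sizes assumed):

```python
import torch

lstm = torch.nn.LSTM(input_size=8, hidden_size=16, num_layers=2)
names = list(lstm.state_dict().keys())
# per layer: weight_ih_l{k}, weight_hh_l{k}, bias_ih_l{k}, bias_hh_l{k}
assert names == ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0',
                 'weight_ih_l1', 'weight_hh_l1', 'bias_ih_l1', 'bias_hh_l1']
```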

Convert LSTM weights to cuDNN format:

import torch

def get_cudnn_lstm_weights(lstm):
    num_layers = lstm.num_layers
    hidden_size = lstm.hidden_size
    # state_dict() yields the parameters layer by layer, in the order
    # weight_ih_l{k}, weight_hh_l{k}, bias_ih_l{k}, bias_hh_l{k}
    wts = list(lstm.state_dict().values())

    all_wts = []
    idx = 0
    for i in range(num_layers):
        # input-hidden weights, stacked along dim 0 in PyTorch's i, f, g, o order
        wtii = wts[idx][:hidden_size, :]
        wtif = wts[idx][hidden_size:hidden_size * 2, :]
        wtig = wts[idx][hidden_size * 2:hidden_size * 3, :]
        wtio = wts[idx][hidden_size * 3:hidden_size * 4, :]
        idx += 1

        # hidden-hidden weights
        wthi = wts[idx][:hidden_size, :]
        wthf = wts[idx][hidden_size:hidden_size * 2, :]
        wthg = wts[idx][hidden_size * 2:hidden_size * 3, :]
        wtho = wts[idx][hidden_size * 3:hidden_size * 4, :]
        idx += 1

        # input-hidden biases
        bii = wts[idx][:hidden_size]
        bif = wts[idx][hidden_size:hidden_size * 2]
        big = wts[idx][hidden_size * 2:hidden_size * 3]
        bio = wts[idx][hidden_size * 3:hidden_size * 4]
        idx += 1

        # hidden-hidden biases
        bhi = wts[idx][:hidden_size]
        bhf = wts[idx][hidden_size:hidden_size * 2]
        bhg = wts[idx][hidden_size * 2:hidden_size * 3]
        bho = wts[idx][hidden_size * 3:hidden_size * 4]
        idx += 1

        # reorder the gates from PyTorch's i, f, g, o to i, f, o, g
        wts1 = [wtii, wtif, wtio, wtig, wthi, wthf, wtho, wthg]
        b1 = [bii, bif, bio, big, bhi, bhf, bho, bhg]

        # transpose and flatten each matrix, then pack all matrices for
        # this layer followed by all biases
        weights = [torch.transpose(x, 0, 1).reshape(-1) for x in wts1]
        biases = [x.reshape(-1) for x in b1]
        all_wts.append(torch.cat(weights + biases, dim=0))

    # one contiguous buffer: layer 0 weights+biases, layer 1 weights+biases, ...
    return torch.cat(all_wts, dim=0) if num_layers > 1 else all_wts[0]
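As a sanity check, the packing is a pure reordering, so the result must contain exactly as many elements as the LSTM’s parameters. Here is a compact single-layer version of the same steps (toy sizes assumed, `split_gates` is just a local helper):

```python
import torch

lstm = torch.nn.LSTM(input_size=4, hidden_size=3, num_layers=1)

def split_gates(t):
    i, f, g, o = t.chunk(4, dim=0)  # PyTorch stacking order: i, f, g, o
    return [i, f, o, g]             # reordered as in the answer: i, f, o, g

mats = split_gates(lstm.weight_ih_l0) + split_gates(lstm.weight_hh_l0)
biases = split_gates(lstm.bias_ih_l0) + split_gates(lstm.bias_hh_l0)
packed = torch.cat([m.t().reshape(-1) for m in mats] +
                   [b.reshape(-1) for b in biases])

# 4*(3*4) input-hidden + 4*(3*3) hidden-hidden + 8*3 bias elements = 108
assert packed.numel() == sum(p.numel() for p in lstm.parameters())
```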