How to implement convolutional GRUs/LSTMs/RNNs

I want to have the standard LSTM/GRU/RNN setup, but with the linear function swapped for a convolution. Is it possible to do this in PyTorch in a clean and efficient manner?

Ideally it would still work with packing, varying sequence lengths, etc.
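
To make the question concrete, here is a rough sketch of the kind of cell I have in mind (the class name ConvLSTMCell and all the shapes are my own, this is not an existing PyTorch module): the four LSTM gates are computed with a single nn.Conv2d over the concatenation of the input frame and the previous hidden state, instead of with nn.Linear.

import torch
import torch.nn as nn

class ConvLSTMCell(nn.Module):
    '''
    Sketch of an LSTM cell where the affine maps for the gates are replaced
    by a 2d convolution. Input x is (batch, in_channels, H, W); the hidden
    state h and the cell state c are (batch, hidden_channels, H, W).
    '''
    def __init__(self, in_channels, hidden_channels, kernel_size=3):
        super().__init__()
        padding = kernel_size // 2  # keep the spatial size unchanged
        ## one convolution produces all 4 gates (i, f, g, o) at once
        self.gates = nn.Conv2d(in_channels + hidden_channels,
                               4 * hidden_channels,
                               kernel_size, padding=padding)

    def forward(self, x, state):
        h, c = state
        gates = self.gates(torch.cat([x, h], dim=1))  # (batch, 4*hidden_channels, H, W)
        i, f, g, o = gates.chunk(4, dim=1)
        i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
        g = torch.tanh(g)
        c_next = f * c + i * g                        # standard LSTM cell update
        h_next = o * torch.tanh(c_next)
        return h_next, (h_next, c_next)

This covers the cell itself, but I don't see how to get the nn.LSTM-style conveniences (stacked layers, packed sequences, cuDNN speed) on top of it, which is what I'm asking about.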

Small sample code showing a trivial way to pass data through it would be super useful, along the lines of what I already do with the built-in nn.LSTM below:

# Based on Robert Guthrie tutorial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from pdb import set_trace as st

torch.manual_seed(1)

def step_by_step(lstm, sequence, hidden):
    '''
    An example of an LSTM processing the sequence one token at a time (one time step at a time).
    '''
    ## process sequence one element at a time
    print()
    print('start processing sequence')
    for i, token in enumerate(sequence):
        print(f'-- i = {i}')
        #print(f'token.size() = {token.size()}')
        ## to add fake batch_size and fake seq_len
        h_n, c_n = hidden # hidden states, cell state
        processed_token = token.view(1, 1, -1) # torch.Size([1, 1, 3])
        print(f'processed_token.size() = {processed_token.size()}')
        print(f'h_n.size() = {h_n.size()}')
        #print(f'processed_token = {processed_token}')
        #print(f'h_n = {h_n}')
        # after each step, hidden contains the hidden state.
        out, hidden = lstm(processed_token, hidden)
    ## print results
    print()
    print(out)
    print(hidden)

def whole_seq_all_at_once(lstm,sequence,hidden):
    '''
        Alternatively, we can process the entire sequence all at once.
        The first value returned by the LSTM is all of the hidden states
        throughout the sequence. The second is just the most recent hidden
        state (compare the last slice of "out" with "hidden" below, they are
        the same). The reason for this is that:
        "out" will give you access to all hidden states in the sequence;
        "hidden" will allow you to continue the sequence and backpropagate,
        by passing it as an argument to the lstm at a later time.
        Below we also add the extra 2nd (batch) dimension before calling the LSTM.
    '''
    h, c = hidden
    Tx = len(sequence)
    ## concatenates list of tensors in the dim 0, i.e. stacks them downwards creating new rows
    sequence = torch.cat(sequence) # (5, 3)
    ## add a singleton dimension of size 1
    sequence = sequence.view(len(sequence), 1, -1) # (5, 1, 3)
    print(f'sequence.size() = {sequence.size()}')
    print(f'h.size() = {h.size()}')
    print(f'c.size() = {c.size()}')
    out, hidden = lstm(sequence, hidden)
    ## "out" will give you access to all hidden states in the sequence
    print()
    print(f'out = {out}')
    print(f'out.size() = {out.size()}') # (5, 1, 6) = (Tx, 1, hidden_size)
    ##
    h_n, c_n = hidden
    print(f'h_n = {h_n}')
    print(f'h_n.size() = {h_n.size()}')
    print(f'c_n = {c_n}')
    print(f'c_n.size() = {c_n.size()}')

if __name__ == '__main__':
    ## model params
    hidden_size = 6
    input_size = 3
    lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size)
    ## make a sequence of length Tx (list of Tx tensors)
    Tx = 5
    sequence = [torch.randn(1, input_size) for _ in range(Tx)]  # make a sequence of length 5
    ## initialize the hidden state.
    hidden = (torch.randn(1, 1, hidden_size), torch.randn(1, 1, hidden_size))
    #step_by_step(lstm,sequence,hidden)
    whole_seq_all_at_once(lstm,sequence,hidden)
    print('DONE \a')
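
For completeness, here is a minimal sketch of how I would drive the hypothetical ConvLSTMCell from above over a sequence, analogous to step_by_step (all shapes are made up):

## assumes torch is imported and the ConvLSTMCell sketch from above is in scope
Tx, batch, in_ch, hid_ch, H, W = 5, 1, 3, 6, 8, 8
cell = ConvLSTMCell(in_ch, hid_ch)
frames = [torch.randn(batch, in_ch, H, W) for _ in range(Tx)]
## initialize the hidden state and cell state with zeros
state = (torch.zeros(batch, hid_ch, H, W), torch.zeros(batch, hid_ch, H, W))
outs = []
for frame in frames:
    out, state = cell(frame, state)
    outs.append(out)
outs = torch.stack(outs)  # (Tx, batch, hid_ch, H, W)
print(outs.size())

What I don't see is how to get the equivalent of torch.nn.utils.rnn.pack_padded_sequence for variable-length sequences with such a cell, hence the question.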

Cross posted: