Need help with a first RNN module

Hello all! I've recently started learning PyTorch, and as a learning exercise I am trying to create an RNN + fully connected module. Unfortunately, when I try to use the module, my network is not learning. In fact, when I do a trace, I find that my parameter gradients all remain None. I'd appreciate any debugging tips or critiques of the module.

import torch
import torch.nn as nn

class Sequence(nn.Module):
    def __init__(self, input_size, hidden_units, output_size):
        super(Sequence, self).__init__()

        self.hidden_units = hidden_units

        self.l1 = nn.Linear(input_size + hidden_units, hidden_units)
        self.swish = Swish()  # user-defined activation, not part of torch.nn (nn.SiLU is Swish with beta = 1)
        self.do = nn.Dropout(0.2)
        self.l2 = nn.Linear(hidden_units, output_size)
        self.tanh = nn.Tanh()

        self.reset_mem()

    def forward(self, input_seq):
        '''
        Start off treating entire training set as one sequence

        input_seq - (very large, 1)
        '''

        if self.a1 is None:
            self.a1 = torch.zeros(self.hidden_units, dtype=torch.float)

        outputs = []
        for i, input_t in enumerate(input_seq):
            input_t = torch.tensor(input_t)  # (4)
            combined = torch.cat((input_t, self.a1))  # (132)
            h1 = self.l1(combined)  # (128)
            self.a1 = self.swish(h1)  # (128)
            do1 = self.do(self.a1)  # (128)
            h2 = self.l2(do1)  # (1)
            output = self.tanh(h2)  # (1)
            outputs.append(output)

        # Outputs is now 12533 long (using default data)

        outputs = torch.tensor(outputs, requires_grad=True)
        outputs = outputs.reshape((len(outputs), 1))  # (12533, 1)

        return outputs

    def reset_mem(self):
        self.a1 = None

The training code looks like this:


    X = np.array([i[0] for i in training_data])
    X = torch.Tensor(X)  # This is (12533, 4)
    y = [i[1] for i in training_data]
    y = torch.Tensor(y)  # This is (12533, 1)

    model = Sequence(4, 128, 1)
    loss_fn = torch.nn.MSELoss(reduction='sum')  # replaces the deprecated size_average=False

    # Use optimizer from the sequence prediction tutorial
    optimizer = torch.optim.LBFGS(model.parameters(), lr=0.8)

    for i in range(training_epochs):
        def closure():
            optimizer.zero_grad()
            y_pred = model(X)
            loss = loss_fn(y_pred, y)
            loss.backward()
            return loss

        optimizer.step(closure)

The optimizer reports that every Parameter's .grad is None (in _gather_flat_grad() in the LBFGS code) and exits immediately.
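When gradients stay None, two quick checks usually locate the break: whether the model output has a grad_fn (if it is None, the output is not connected to the parameters), and which parameters still have .grad set to None after backward(). A minimal sketch, assuming a toy nn.Linear in place of the full module; params_without_grad is a hypothetical helper, and copying values with .item() into a new tensor mimics what the torch.tensor(...) constructor does to a list (it always copies the data into a fresh leaf).

```python
import torch
import torch.nn as nn


def params_without_grad(model):
    """Return names of parameters whose .grad is still None after backward()."""
    return [name for name, p in model.named_parameters() if p.grad is None]


torch.manual_seed(0)
layer = nn.Linear(3, 1)            # toy stand-in for the RNN module
x = torch.randn(5, 3)
outs = [layer(row) for row in x]   # list of shape-(1,) tensors, each with a grad_fn

# Disconnected: copying the values into a new tensor makes a fresh leaf.
detached = torch.tensor([o.item() for o in outs], requires_grad=True)
print(detached.grad_fn)            # None
detached.sum().backward()          # runs fine, but only fills detached.grad
print(params_without_grad(layer))  # ['weight', 'bias']

# Connected: torch.cat is recorded in the graph, so gradients flow back.
connected = torch.cat(outs)
print(connected.grad_fn is not None)  # True
connected.sum().backward()
print(params_without_grad(layer))     # []
```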

So, I think I know the problem. Let's see.

You are doing this to signal that your model requires gradients:

outputs = torch.tensor(outputs, requires_grad=True)

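The tensor constructor torch.tensor(outputs, requires_grad=True) takes outputs, which is a list of Tensors, and copies their values into a brand-new leaf tensor. The constructor doesn't know how to backprop through a list (of Tensors or numbers), so the new tensor is cut off from the graph: backward() runs without error, but no gradient ever reaches the model's parameters.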

What you probably want to do here is:

outputs = torch.cat(outputs) # unlike a constructor which treats inputs as opaque entities, torch.cat knows exactly how to backprop through the list of Tensors
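A minimal sketch of the fix, assuming each per-step output has shape (1,) as in the module above: torch.cat joins them into shape (N,), so the existing reshape to (N, 1) is still needed, while torch.stack produces (N, 1) directly. Both record an operation in the graph, so gradients reach the layer. The small layer and sizes here are illustrative only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
head = nn.Linear(4, 1)     # stands in for l2 + tanh in the question
steps = torch.randn(6, 4)  # 6 time steps, 4 features each

outputs = [torch.tanh(head(x_t)) for x_t in steps]  # each has shape (1,)

catted = torch.cat(outputs).reshape(len(outputs), 1)  # (6,) -> (6, 1)
stacked = torch.stack(outputs)                        # (6, 1) directly
print(catted.shape, stacked.shape)  # torch.Size([6, 1]) torch.Size([6, 1])
assert torch.equal(catted, stacked)

loss = nn.MSELoss(reduction="sum")(stacked, torch.zeros(6, 1))
loss.backward()
print(head.weight.grad is not None)  # True
```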

You, sir, are fantastic! This worked and makes perfect sense.

Not surprisingly, I immediately ran into another question (there is probably good documentation for this somewhere).

Initially I got:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

I followed that advice and restarted training, but now each internal LBFGS iteration takes longer than the one before.

Is the graph growing with each iteration? More generally, what would you suggest I read or search for to better understand what is happening here?

Thanks again for your solution to my problem and the great library!
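One plausible source of both the retain_graph error and the slowdown, assuming the closure above is used unchanged: forward() stores the last hidden state in self.a1 and reuses it on the next call, because reset_mem() is never called. The next forward therefore links into the previous call's graph, whose buffers backward() has already freed; with retain_graph=True that chain is kept alive and grows on every LBFGS evaluation. The sketch below uses a hypothetical TinyRNN with the same pattern (nn.SiLU in place of the custom Swish) to reproduce the error, then resets the state at the start of each closure.

```python
import torch
import torch.nn as nn


class TinyRNN(nn.Module):
    """Hypothetical cut-down version of the Sequence module (same hidden-state pattern)."""

    def __init__(self, input_size=4, hidden=8):
        super().__init__()
        self.hidden = hidden
        self.l1 = nn.Linear(input_size + hidden, hidden)
        self.act = nn.SiLU()  # x * sigmoid(x), i.e. Swish with beta = 1
        self.l2 = nn.Linear(hidden, 1)
        self.reset_mem()

    def reset_mem(self):
        self.a1 = None

    def forward(self, seq):
        if self.a1 is None:
            self.a1 = torch.zeros(self.hidden)
        outs = []
        for x_t in seq:
            self.a1 = self.act(self.l1(torch.cat((x_t, self.a1))))
            outs.append(torch.tanh(self.l2(self.a1)))
        return torch.stack(outs)  # (seq_len, 1)


torch.manual_seed(0)
model = TinyRNN()
X, y = torch.randn(10, 4), torch.randn(10, 1)
loss_fn = nn.MSELoss(reduction="sum")

# Reproduce: the second forward starts from self.a1, which still belongs
# to the first call's graph, whose buffers backward() already freed.
loss_fn(model(X), y).backward()
try:
    loss_fn(model(X), y).backward()
    reproduced = False
except RuntimeError:
    reproduced = True
print(reproduced)  # True

# Fix: start every closure evaluation from a fresh hidden state.
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.8)


def closure():
    optimizer.zero_grad()
    model.reset_mem()
    loss = loss_fn(model(X), y)
    loss.backward()
    return loss


for _ in range(3):
    optimizer.step(closure)
```

If the hidden state should instead carry over from one call to the next (truncated backpropagation through time), detaching it, e.g. self.a1 = self.a1.detach(), stops gradients at the call boundary so retain_graph=True is not needed.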

On the surface, your code snippet looks fine, so I'm not sure where the memory growth (holding onto the graph) is coming from.

A common pattern that dangerously grows the graph across iterations is:

total_loss += loss # for reporting / logging purposes

Here, loss is a zero-dimensional Tensor that is part of the autograd graph, so total_loss records the addition too (and hence holds onto every iteration's graph).
The right thing to do would be:

total_loss += loss.item() # loss.item() converts it into a python number
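A minimal sketch of the difference, using a throwaway nn.Linear model, SGD and random data (all assumptions for illustration): accumulating the tensor leaves total_tensor with a grad_fn that chains back through every iteration's graph, while accumulating loss.item() stores a plain Python float.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

total_tensor = 0   # becomes a Tensor on the first +=, then keeps extending the graph
total_float = 0.0  # plain Python number, nothing from the graph is retained

for _ in range(5):
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    opt.step()
    total_tensor += loss        # keeps a reference to this iteration's graph
    total_float += loss.item()  # just the value

print(type(total_tensor), total_tensor.grad_fn is not None)  # <class 'torch.Tensor'> True
print(type(total_float))                                     # <class 'float'>
```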