AttributeError: 'CudnnRNN' object has no attribute '_nested_output'

I managed to make my PyTorch code work on the CPU. While porting it over to the GPU, I got stuck at backpropagation with this error: AttributeError: 'CudnnRNN' object has no attribute '_nested_output'

The API is structured this way (one call per sequence step) because it has to interact with other code that requires calls in that form.

Simple LSTM (single input with multiple hidden states that are updated)

from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt


class Net(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, bias, dropout):
        super(Net, self).__init__()
        self.rnn = nn.LSTM(input_size=input_size,
                           hidden_size=hidden_size,
                           num_layers=num_layers,
                           bias=bias,
                           dropout=dropout)

def input_var(i):
    # Wrap a single scalar as a (seq_len, batch, input_size) = (1, 1, 1) tensor
    test = np.array([i])
    input_var = test.reshape(1, 1, 1)  # (seq_len, batch, input_size)
    input_var = torch.from_numpy(input_var).float()
    return input_var


def label_var(i):
    # Target is simply 4 * input, shaped (seq_len, batch) = (1, 1) for a single step
    test = np.array([i*4])
    label_var = test.reshape(1, 1)
    label_var = torch.from_numpy(label_var).float()
    return label_var


class lstmModule:
    def __init__(self, input_size, hidden_size, num_layers, bias, dropout,
                 seq_len, batch_size, meta_lr, n_meta_iter):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.dropout = dropout
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.meta_lr = meta_lr
        self.n_meta_iter = n_meta_iter

        self.net = Net(input_size=input_size,
                       hidden_size=hidden_size,
                       num_layers=num_layers,
                       bias=bias,
                       dropout=dropout)

        self.net.cuda()

        self.h0 = Variable(torch.randn(self.num_layers,
                                       self.batch_size,
                                       self.hidden_size)).cuda()

        self.c0 = Variable(torch.randn(self.num_layers,
                                       self.batch_size,
                                       self.hidden_size)).cuda()

        self.optimizer = optim.Adam(self.net.rnn.parameters(), lr=self.meta_lr)

        self.loss_lst = []
        self.loss_lst2 = []

    def lstm_forward(self, seq_num, meta_num):
        print('i fed', seq_num)
        def pseudo_loss(output, label):
            return torch.mean(torch.sum(torch.abs(output - label)))

        inp = input_var(seq_num)
        input = Variable(inp).cuda()

        lab = label_var(seq_num)
        label = Variable(lab).cuda()

        if seq_num == 0:

            # Ensure clear gradient buffer
            self.optimizer.zero_grad()
            self.loss_tot = [0 for i in range(self.hidden_size)]

            # Label concatenation
            self.label_all = label

            # LSTM
            output, hn = self.net.rnn(input, (self.h0, self.c0))
            output = 100 * output

            op = [output[:, :, i] for i in range(self.hidden_size)]

            self.output_all = op
            self.h, self.c = hn
            print('Done', seq_num)
        else:
            self.label_all = torch.cat((self.label_all, label), 0)
            output, hn = self.net.rnn(input, (self.h, self.c))
            output = 100 * output
            op = [output[:, :, i] for i in range(self.hidden_size)]
            self.h, self.c = hn
            self.output_all = [torch.cat((self.output_all[i], op[i]), 0) for i in range(self.hidden_size)]
            print('Done', seq_num)
        print('-'*10)
        if seq_num == (self.seq_len - 1):
            # Get loss
            self.loss_tot = [self.loss_tot[i] + pseudo_loss(self.output_all[i], self.label_all) for i in range(self.hidden_size)]

            # Append loss
            self.loss_lst.append(self.loss_tot[0].cpu().data.numpy()[0])
            self.loss_lst2.append(self.loss_tot[1].cpu().data.numpy()[0])

            # Backprop
            print(len(self.loss_tot))
            print(self.loss_tot)
            for k in range(self.hidden_size):
                print('backprop', k)
                self.loss_tot[k].backward(retain_variables=True)

            # Update optimizer
            self.optimizer.step()

        if seq_num == (self.seq_len - 1) and meta_num == (self.n_meta_iter - 1):
            # print(len(self.loss_lst))
            print('Loss 1', self.loss_tot[0].cpu().data.numpy())
            print('Loss 2', self.loss_tot[1].cpu().data.numpy())
            plt.clf()
            plt.title('Loss Curve')
            plt.plot(self.loss_lst, label='Hidden 1')
            plt.plot(self.loss_lst2, label='Hidden 2')
            plt.legend(loc='best')
            plt.savefig('loss.png')

    def lstm_check(self, seq_num):
        inp = input_var(seq_num)
        input = Variable(inp).cuda()
        lab = label_var(seq_num)
        label = Variable(lab).cuda()

        if seq_num == 0:
            # Ensure clear gradient buffer
            self.optimizer.zero_grad()
            self.loss_tot = [0 for i in range(self.hidden_size)]

            # Label concatenation
            self.label_all = label

            # LSTM
            output, hn = self.net.rnn(input, (self.h0, self.c0))
            output = 100 * output
            op = [output[:, :, i] for i in range(self.hidden_size)]
            self.output_all = op
            self.h, self.c = hn
        else:
            self.label_all = torch.cat((self.label_all, label), 0)
            output, hn = self.net.rnn(input, (self.h, self.c))
            output = 100 * output
            op = [output[:, :, i] for i in range(self.hidden_size)]
            self.h, self.c = hn
            self.output_all = [torch.cat((self.output_all[i], op[i]), 0) for i
                               in
                               range(self.hidden_size)]

        if seq_num == (self.seq_len - 1):
            print('-' * 10)
            print(self.output_all[0].cpu().data.numpy())
            print(self.label_all.cpu().data.numpy())
            print('-' * 10)
            print(self.output_all[1].cpu().data.numpy())
            print(self.label_all.cpu().data.numpy())

N_meta = 100
LR_meta = 0.1
N_seq = 4
batch_size = 1
layers = 4
input_size = 1
hidden_size = 10

# Initialize the LSTM module once
# (input_size, hidden_size, num_layers, bias, dropout, seq_len, batch_size, meta_lr, n_meta_iter)
print('Initializing LSTM')
lstm = lstmModule(input_size, hidden_size, layers, True, 0.1, N_seq, batch_size, LR_meta, N_meta)
print('Initialized LSTM')

# Run through meta iterations
print('Training')
for j in range(N_meta):
    # Run through each step
    for i in range(N_seq):
        print('i start', i)
        lstm.lstm_forward(i, j)
print('Done Training')

# Check
print('Checking')
for i in range(N_seq):
    lstm.lstm_check(i)
print('Done Checking')

Error:

Traceback (most recent call last):
  File "test.py", line 202, in <module>
    lstm.lstm_forward(i, j)
  File "test.py", line 127, in lstm_forward
    self.loss_tot[k].backward(retain_variables=True)
  File "/usr/local/lib/python2.7/dist-packages/torch/autograd/variable.py", line 158, in backward
    self._execution_engine.run_backward((self,), (gradient,), retain_variables)
  File "/usr/local/lib/python2.7/dist-packages/torch/autograd/function.py", line 208, in backward
    nested_gradients = _unflatten(gradients, self._nested_output)
AttributeError: 'CudnnRNN' object has no attribute '_nested_output'

Backpropagation works for the loss on the first hidden state, but fails from the second one onward.
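
For reference, I don't think the problem depends on the rest of my class. Here is a stripped-down sketch on the same PyTorch version as in the traceback (Variable API, retain_variables); as far as I can tell, two backward passes through the same cuDNN LSTM graph are enough to trigger it:

import torch
import torch.nn as nn
from torch.autograd import Variable

rnn = nn.LSTM(input_size=1, hidden_size=2, num_layers=1).cuda()
x = Variable(torch.randn(1, 1, 1)).cuda()    # (seq_len, batch, input_size)
h0 = Variable(torch.randn(1, 1, 2)).cuda()   # (num_layers, batch, hidden_size)
c0 = Variable(torch.randn(1, 1, 2)).cuda()

out, _ = rnn(x, (h0, c0))
loss_a = torch.sum(torch.abs(out[:, :, 0]))
loss_b = torch.sum(torch.abs(out[:, :, 1]))

loss_a.backward(retain_variables=True)  # first backward through the cuDNN RNN works
loss_b.backward(retain_variables=True)  # second backward should raise the AttributeError above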

We’re aware of that issue. Right now there’s an error in RNNs that occurs when you try to backprop through them multiple times. However, your code doesn’t need to do that, and would be much more efficient if it didn’t.

Replacing

for k in range(self.hidden_size):
    self.loss_tot[k].backward(retain_variables=True)

with

sum(self.loss_tot).backward()

will be much better. Backpropagating from a number of losses is equivalent to backpropagating from their sum, since the gradients are accumulated. It also saves a lot of computation, because a single backward call batches the operations for all the losses and executes them in one go.
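
To make the equivalence concrete, here's a small sketch using a plain Linear layer on the CPU (so the cuDNN RNN issue above doesn't get in the way), written against the old API used in this thread (Variable, retain_variables; on newer versions use retain_graph instead). It checks that one backward over the summed losses accumulates the same gradients as separate backward calls:

import torch
import torch.nn as nn
from torch.autograd import Variable

lin = nn.Linear(3, 2)
x = Variable(torch.randn(4, 3))

# Separate backward calls: the gradients of the two losses accumulate in .grad
lin.zero_grad()
out = lin(x)
losses = [torch.sum(torch.abs(out[:, i])) for i in range(2)]
for loss in losses:
    loss.backward(retain_variables=True)  # keep the shared graph alive between calls
grad_separate = lin.weight.grad.clone()

# One backward over the summed loss: a single pass, same accumulated gradients
lin.zero_grad()
out = lin(x)
losses = [torch.sum(torch.abs(out[:, i])) for i in range(2)]
sum(losses).backward()

print((grad_separate - lin.weight.grad).abs().max())  # should print (approximately) zero

The printed difference should be zero up to floating point, while the summed version only traverses the graph once.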


Thanks for the prompt reply on this issue.

Thankfully your recommendation works. Really appreciate it.

Cheers!
Ritchie