CUDA Out of Memory

This is a self-contained script that you can run with python test_rnn.py.

It works with a small value of hidden_size (line 178), like 100 or even 1000. But once it reaches 10000 and above, which is what I need, I hit the CUDA out-of-memory error.
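
For scale, here is a rough back-of-envelope on the parameter count alone at these sizes (separate from the script below, assuming the standard per-layer LSTM weight shapes):

# Not part of test_rnn.py: rough fp32 estimate of the LSTM weight memory
def lstm_param_count(input_size, hidden_size, num_layers):
    h = hidden_size
    # first layer: weight_ih (4h x input_size), weight_hh (4h x h), two biases of 4h
    total = 4 * h * input_size + 4 * h * h + 8 * h
    # remaining layers: weight_ih (4h x h), weight_hh (4h x h), two biases of 4h
    total += (num_layers - 1) * (8 * h * h + 8 * h)
    return total

n = lstm_param_count(1, 15000, 4)
print(n)                  # ~6.3e9 parameters
print(n * 4 / 1024.0**3)  # ~23 GB of fp32 weights before gradients and Adam state

Gradients add another full copy of that, and Adam keeps two more buffers per parameter.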

'''
GPU LSTM
    1 single input
    1 single output
'''

from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt


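# Thin wrapper around nn.LSTM; no forward() is defined, the LSTM is called
# directly via net.rnn below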
class Net(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, bias, dropout):
        super(Net, self).__init__()
        self.rnn = nn.LSTM(input_size=input_size,
                           hidden_size=hidden_size,
                           num_layers=num_layers,
                           bias=bias,
                           dropout=dropout)


def input_var(i):
    # Single scalar input i, shaped (seq_len, batch, input_size) = (1, 1, 1)
    test = np.array([i])
    input_var = test.reshape(1, 1, 1)
    input_var = torch.from_numpy(input_var).float()
    return input_var


def label_var(i):
    # Target is 4 * i, reshaped to (1, 1) to match each per-hidden-unit output slice
    test = np.array([i * 4])
    label_var = test.reshape(1, 1)
    label_var = torch.from_numpy(label_var).float()
    return label_var


class lstmModule:
    def __init__(self, input_size, hidden_size, num_layers, bias, dropout,
                 seq_len, batch_size, meta_lr, n_meta_iter):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.dropout = dropout
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.meta_lr = meta_lr
        self.n_meta_iter = n_meta_iter

        self.net = Net(input_size=input_size,
                       hidden_size=hidden_size,
                       num_layers=num_layers,
                       bias=bias,
                       dropout=dropout)

        self.net.cuda()

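        # Initial hidden and cell states, shape (num_layers, batch, hidden_size);
        # reused at the start of every meta iteration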
        self.h0 = Variable(torch.randn(self.num_layers,
                                       self.batch_size,
                                       self.hidden_size)).cuda()

        self.c0 = Variable(torch.randn(self.num_layers,
                                       self.batch_size,
                                       self.hidden_size)).cuda()

        self.optimizer = optim.Adam(self.net.rnn.parameters(), lr=self.meta_lr)

        self.loss_lst = []

    def lstm_forward(self, seq_num, meta_num):
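        # One forward step at sequence position seq_num in meta iteration meta_num.
        # Outputs and labels are accumulated across steps; at the final step the
        # total loss is backpropagated and the optimizer is stepped.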
        def pseudo_loss(output, label):
            return torch.mean(torch.sum(torch.abs(output - label)))

        inp = input_var(seq_num)
        input = Variable(inp).cuda()

        lab = label_var(seq_num)
        label = Variable(lab).cuda()

        if seq_num == 0:

            # Ensure clear gradient buffer
            self.optimizer.zero_grad()
            self.loss_tot = [0 for i in range(self.hidden_size)]

            # Label concatenation
            self.label_all = label

            # LSTM
            output, hn = self.net.rnn(input, (self.h0, self.c0))
            output = 100 * output

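            # Split the step output into one tensor per hidden unit; with
            # hidden_size = 15000 this creates a list of 15000 tensors at every step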
            op = [output[:, :, i] for i in range(self.hidden_size)]

            self.output_all = op
            #             print('1 step length:', len(self.output_all))
            self.h, self.c = hn
        else:
            self.label_all = torch.cat((self.label_all, label), 0)
            output, hn = self.net.rnn(input, (self.h, self.c))
            output = 100 * output
            op = [output[:, :, i] for i in range(self.hidden_size)]
            self.h, self.c = hn
            self.output_all = [torch.cat((self.output_all[i], op[i]), 0) for i in range(self.hidden_size)]

        if seq_num == (self.seq_len - 1):
            # Get loss
            self.loss_tot = [self.loss_tot[i] + pseudo_loss(self.output_all[i], self.label_all) for i in range(self.hidden_size)]

            # Append loss
            self.loss_lst.append(sum(self.loss_tot).cpu().data.numpy()[0])

            # Backprop
            sum(self.loss_tot).backward()

            # Update optimizer
            self.optimizer.step()

        if seq_num == (self.seq_len - 1) and meta_num == (self.n_meta_iter - 1):
            print('Loss 1', self.loss_tot[0].cpu().data.numpy())
            print('Loss 2', self.loss_tot[1].cpu().data.numpy())
            plt.clf()
            plt.title('Loss Curve')
            plt.plot(self.loss_lst, label='Loss Curve')
            plt.legend(loc='best')
            plt.savefig('loss.png')

    def lstm_check(self, seq_num):
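        # Forward pass only (no backward or optimizer step); at the final sequence
        # step, print the first two hidden units' accumulated outputs against the labels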
        inp = input_var(seq_num)
        input = Variable(inp).cuda()
        lab = label_var(seq_num)
        label = Variable(lab).cuda()

        if seq_num == 0:
            # Ensure clear gradient buffer
            self.optimizer.zero_grad()
            self.loss_tot = [0 for i in range(self.hidden_size)]

            # Label concatenation
            self.label_all = label

            # LSTM
            output, hn = self.net.rnn(input, (self.h0, self.c0))
            output = 100 * output
            op = [output[:, :, i] for i in range(self.hidden_size)]
            self.output_all = op
            self.h, self.c = hn
        else:
            self.label_all = torch.cat((self.label_all, label), 0)
            output, hn = self.net.rnn(input, (self.h, self.c))
            output = 100 * output
            op = [output[:, :, i] for i in range(self.hidden_size)]
            self.h, self.c = hn
            self.output_all = [torch.cat((self.output_all[i], op[i]), 0) for i in range(self.hidden_size)]

        if seq_num == (self.seq_len - 1):
            print('-' * 10)
            print(self.output_all[0].cpu().data.numpy())
            print(self.label_all.cpu().data.numpy())
            print('-' * 10)
            print(self.output_all[1].cpu().data.numpy())
            print(self.label_all.cpu().data.numpy())

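# Hyperparameters: hidden_size is the value that triggers the out-of-memory error
# once it reaches roughly 10000 and above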
N_meta = 10
LR_meta = 0.1
N_seq = 4
batch_size = 1
layers = 4
input_size = 1
hidden_size = 15000

# Initialize the LSTM module once
# (input_size, hidden_size, num_layers, bias, dropout, seq_len, batch_size, meta_lr, n_meta_iter)
print('Initializing LSTM')
lstm = lstmModule(input_size, hidden_size, layers, True, 0.1, N_seq, batch_size, LR_meta, N_meta)
print('Initialized LSTM')

# Run through meta iterations
print('Training')
for j in range(N_meta):
    print('Meta iteration', j)
    # Run through each step of the sequence
    for i in range(N_seq):
        lstm.lstm_forward(i, j)
print('Done Training')

# Check
print('-' * 10)
print('Checking')
for i in range(N_seq):
    lstm.lstm_check(i)
print('Done Checking')