[resolved] LSTM with image data - how to save GPU memory?

I’m trying to run an LSTM on image data and I quickly run out of GPU memory. My loop looks something like this:

x = torch.rand(20, 2, 3, 224, 738)  # T x B x C x H x W
# init Variables h and c to the proper size
y = []
for xt in x:
    xt = xt.cuda()  # move one time slice to the GPU
    yt, (h, c) = lstm(xt, (h, c))
    y.append(yt)  # collect the per-step outputs

What I want to achieve is this: take a time slice from x (called xt above), transfer it to the device, perform the LSTM operation on the GPU, and then bring xt back to the host while preserving the Variable connections.
Any ideas?
Thanks :)

Posting my solution here for reference. It divides the data into ‘groups’ of at most M images, where the user picks M based on the GPU memory available and the resolution of the images.
The N x T x C x H x W data is then processed in B x G x C x H x W chunks, with the batch size B and the number of time slices G derived from M so that B*G never exceeds M.
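To make the chunking arithmetic concrete, here is a minimal sketch in plain Python 3 (the post's own code is Python 2). The helper name `plan_groups` and the example sizes are made up for illustration, and the random start offsets used in the full loop are left out so that the result is deterministic.

```python
def plan_groups(N, T, M):
    """Split N sequences of T time slices into chunks of at most M images.

    Returns (B, G, windows), where each window is a tuple
    (b_lo, b_hi, g_lo, g_hi) of half-open index ranges.
    """
    G = min(T, M)        # time slices per chunk
    B = min(N, M // G)   # sequences per chunk, so that B * G <= M
    windows = []
    for b in range(N // B):
        for g in range(T // G):
            windows.append((b * B, (b + 1) * B, g * G, (g + 1) * G))
    return B, G, windows


# Example sizes (assumed): 8 sequences, 20 time slices each, room for 64 images.
B, G, windows = plan_groups(N=8, T=20, M=64)
print(B, G, windows)
```

With these sizes each chunk holds 3 x 20 = 60 images, and 2 of the 8 sequences are left over in this deterministic version. In the full loop, the random start offset changes which sequences are left out on each pass.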

import numpy as np
import torch
from torch.autograd import Variable

# data is N x T x C x H x W
# target is N x T x d
M = 64  # no. of images that can fit on the GPU
N, T = data.size(0), data.size(1)
G = min(T, M)  # no. of time slices that can fit on the GPU
B = min(N, M // G)  # batch size that can fit on the GPU (integer division)

if train:
  data_var   = Variable(data, requires_grad=True)
  target_var = Variable(target, requires_grad=False)
else:
  # inference only: volatile Variables do not build a graph
  data_var   = Variable(data, volatile=True)
  target_var = Variable(target, volatile=True)

loss_accum = 0
b_start = np.random.randint(N % B + 1)  # random start offset in [0, N % B]
for b in xrange(N // B):
  b_idx = b_start + torch.LongTensor(xrange(b*B, (b+1)*B))  # indices of B sequences
  xb = torch.index_select(data_var, dim=0, index=Variable(b_idx))
  tb = torch.index_select(target_var, dim=0, index=Variable(b_idx).cuda())
  g_start = np.random.randint(T % G + 1)  # random start offset in [0, T % G]
  for g in xrange(T // G):
    g_idx = g_start + torch.LongTensor(xrange(g*G, (g+1)*G))  # indices of G time slices
    xg = torch.index_select(xb, dim=1, index=Variable(g_idx))
    tg = torch.index_select(tb, dim=1, index=Variable(g_idx).cuda())
    output = model(xg, cuda=cuda, async=True)

    if criterion is not None:
      loss = criterion(output, tg) 
      loss_accum += loss.data[0]

      if train:
        # SGD step

Here model.reset_hidden_states() (its call is not shown in the snippet above) re-initializes the hidden states with random values drawn from a normal distribution and ‘repackages’ them, as in the thread “Help clarifying repackage_hidden in word_language_model”.
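For readers unfamiliar with ‘repackaging’, the sketch below shows the idea using the current tensor API (`.detach()`) rather than the Variable API in the post. It is not the author's `reset_hidden_states()`, which is not shown. The `nn.LSTM` with toy sizes stands in for the image model, and the loss is a placeholder. The `repackage_hidden` helper follows the pattern from the PyTorch word_language_model example.

```python
import torch
import torch.nn as nn


def repackage_hidden(h):
    """Detach hidden states from the graph that produced them."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)


torch.manual_seed(0)
lstm = nn.LSTM(input_size=4, hidden_size=8)  # toy stand-in for the image model
x = torch.rand(6, 2, 4)                      # T x B x features (assumed sizes)
G = 3                                        # time slices per group

# Initial hidden and cell states drawn from a normal distribution.
hidden = (torch.randn(1, 2, 8), torch.randn(1, 2, 8))

for g in range(x.size(0) // G):
    xg = x[g * G:(g + 1) * G]          # one group of G time slices
    out, hidden = lstm(xg, hidden)
    loss = out.pow(2).mean()           # placeholder loss
    lstm.zero_grad()
    loss.backward()                    # an optimizer step would go here
    # Cut the graph so the next group does not backpropagate into this one.
    hidden = repackage_hidden(hidden)
```

Detaching keeps the values of the hidden states but drops their history, so each backward pass covers one group only and the graph of earlier groups can be freed.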