I’m trying to use an LSTM on image data and am quickly running out of device memory. My loop runs something like this:
x = torch.rand(20, 2, 3, 224, 738)  # T x B x C x H x W
# init Variables h and c to the proper size
y = []
for xt in x:
    xt = xt.cuda()
    yt, (h, c) = lstm(xt, (h, c))
    y.append(yt)
What I want to achieve is this: take a time slice from x (called xt above), transfer it to the device, run the LSTM on the GPU, and then get the result back to the host, preserving the Variable connections (i.e. the autograd graph).
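In other words, something like the sketch below (a toy nn.LSTMCell on flattened frames stands in for my real model, so the sizes are made up; the point is that .cpu() is a differentiable copy, so outputs collected on the host stay connected to the graph):

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

feat, hidden = 3 * 8 * 8, 16                  # toy sizes, not the real 3 x 224 x 738
lstm = nn.LSTMCell(feat, hidden).cuda()

x = Variable(torch.rand(20, 2, feat))         # T x B x feat, kept on the host
h = Variable(torch.zeros(2, hidden)).cuda()
c = Variable(torch.zeros(2, hidden)).cuda()

y = []
for xt in x:
    xt = xt.cuda()                            # move only the current time slice to the device
    h, c = lstm(xt, (h, c))
    y.append(h.cpu())                         # differentiable copy of the output back to the host
torch.cat(y).sum().backward()                 # gradients flow back through the device copies
```

Even with the outputs on the host, the activations autograd saves for backward stay on the GPU, so memory still grows with the sequence length.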
Any ideas?
Thanks 
Posting my solution here for reference. The solution divides the data into ‘groups’ of M images; the user specifies M based on the size of their GPU and the resolution of the images. The solution loops through the N x T x C x H x W data by figuring out B x G x C x H x W batches based on the M value.
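For concreteness, with hypothetical sizes (M = 64 as below, 10 sequences of 20 frames each), the grouping arithmetic works out to:

```python
M = 64                 # images that fit on the GPU at once
N, T = 10, 20          # hypothetical: 10 sequences of 20 frames each
G = min(T, M)          # G = 20 -> a whole sequence's time slices fit in one group
B = min(N, M // G)     # B = min(10, 64 // 20) = 3 sequences per chunk
# each chunk sent to the GPU is B x G x C x H x W = 3 x 20 x C x H x W,
# i.e. at most B * G = 60 <= M images at a time
```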
import numpy as np
import torch
from torch.autograd import Variable

# data is N x T x C x H x W
# target is N x T x d
M = 64  # no. of images that can fit on the GPU
N, T = data.size(0), data.size(1)
G = min(T, M)       # no. of time slices that can fit on the GPU
B = min(N, M // G)  # batch size that can fit on the GPU

if train:
    data_var = Variable(data, requires_grad=True)
    target_var = Variable(target, requires_grad=False)
else:
    data_var = Variable(data, volatile=True)
    target_var = Variable(target, volatile=True)

loss_accum = 0
# random start offsets so different windows are covered across epochs
b_start = np.random.randint(N % B + 1)
for b in xrange(N // B):
    # select a batch of B sequences (the full data stays on the host)
    b_idx = b_start + torch.LongTensor(xrange(b * B, (b + 1) * B))
    xb = torch.index_select(data_var, dim=0, index=Variable(b_idx))
    tb = torch.index_select(target_var, dim=0, index=Variable(b_idx).cuda())
    model.reset_hidden_states(B)
    g_start = np.random.randint(T % G + 1)
    for g in xrange(T // G):
        # select a group of G consecutive time slices from the batch
        g_idx = g_start + torch.LongTensor(xrange(g * G, (g + 1) * G))
        xg = torch.index_select(xb, dim=1, index=Variable(g_idx))
        tg = torch.index_select(tb, dim=1, index=Variable(g_idx).cuda())
        model.detach_hidden_states()  # 'repackage' so backprop stops at the group boundary
        output = model(xg, cuda=cuda, async=True)
        if criterion is not None:
            loss = criterion(output, tg)
            loss_accum += loss.data[0]
            if train:
                # SGD step
                optim.learner.zero_grad()
                loss.backward()
                optim.learner.step()
where model.reset_hidden_states() re-initializes the hidden states with random values from a normal distribution and ‘repackages’ them, as in Help clarifying repackage_hidden in word_language_model.
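Those two helpers are not shown above; a rough sketch of what they could look like (hypothetical class name and layer sizes, assuming the model keeps its hidden state as an (h, c) pair):

```python
import torch
from torch.autograd import Variable

class Seq2SeqModel(torch.nn.Module):  # hypothetical stand-in for the posted model
    def __init__(self, feat_size, hidden_size):
        super(Seq2SeqModel, self).__init__()
        self.hidden_size = hidden_size
        self.lstm = torch.nn.LSTM(feat_size, hidden_size)
        self.h, self.c = None, None

    def reset_hidden_states(self, B):
        # fresh hidden states drawn from a normal distribution, one per sequence in the batch
        self.h = Variable(torch.randn(1, B, self.hidden_size)).cuda()
        self.c = Variable(torch.randn(1, B, self.hidden_size)).cuda()

    def detach_hidden_states(self):
        # 'repackage': keep the values but drop their history so the graph does not grow
        self.h = Variable(self.h.data)
        self.c = Variable(self.c.data)

    def forward(self, x):
        # x: G x B x feat_size; the hidden state carries over between groups
        out, (self.h, self.c) = self.lstm(x, (self.h, self.c))
        return out
```

Rebuilding the states as fresh Variables is the same ‘repackaging’ trick used in the word_language_model example linked above.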