Question about tensor lifespan

Hey! New to PyTorch, coming from Keras. Really loving it!

I don't know if the title of this post is correct, as I did not really know how to formulate my question.

I am training a GAN in PyTorch on the LSUN bedroom dataset. When I train my discriminator with data from the generator, I have noticed that I need to copy the generator's output into a new tensor, otherwise my memory use goes up with every iteration. Why is this? I think there is something basic I have yet to understand, so it's best to ask instead of just accepting that that's how you train multiple networks with each other's outputs as inputs.

Here is the part of the code I am curious about:

fake = netG(get_noise(64))
# The line below is the one I am wondering about
fake = torch.Tensor(fake.size()).cuda().copy_(fake.data)
fake = Variable(fake)
prediction = netD(fake)

Without copying the generator's output into a new tensor before using it as input to the discriminator, memory use increases until an out-of-memory exception is thrown.

I am using CUDA for the networks and variables.

Thanks in advance!

The two lines:

fake = torch.Tensor(fake.size()).cuda().copy_(fake.data)
fake = Variable(fake)

break the computation history. This means that when you backprop through your discriminator it won’t backprop through the generator. It might also mean that by dropping the computation history, you’re allowing some memory to be freed.
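For illustration, here is a minimal sketch of the idiomatic way to cut the graph between generator and discriminator: calling `detach()` on the generator's output instead of manually copying it. The tiny `nn.Linear` modules below are hypothetical stand-ins for your `netG` and `netD`, and this uses current PyTorch, where `Variable` has been merged into `Tensor`:

```python
import torch
import torch.nn as nn

# Hypothetical tiny generator/discriminator standing in for netG and netD.
netG = nn.Linear(8, 4)
netD = nn.Linear(4, 1)

noise = torch.randn(2, 8)
fake = netG(noise)

# detach() returns a tensor that shares the same data but is cut off from
# the graph that produced it, so backprop through netD stops here and the
# generator's intermediate buffers can be freed.
prediction = netD(fake.detach())
prediction.sum().backward()

# The discriminator received gradients; the generator did not.
assert netD.weight.grad is not None
assert netG.weight.grad is None
```

This has the same graph-breaking effect as the copy in your code, without allocating a second tensor.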

Typically, your memory usage shouldn’t go up between mini-batches. It might happen if you’re holding on to the result of some computation. A common mistake is to do something like:

sample_loss = loss_fn(net(input), target)
sample_loss.backward()
# etc. etc.
total_loss += sample_loss

to try to average the loss over an epoch. The problem is that `total_loss` then holds onto the entire computation history that produced each summed loss. Instead do:

sample_loss = loss_fn(net(input), target)
sample_loss.backward()
# etc. etc.
total_loss += sample_loss.data

(i.e. if you’re summing the loss across an epoch, you just want to accumulate the value, but not save the computation history)
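The difference is easy to see in a self-contained sketch. This uses current PyTorch (where `Variable` is merged into `Tensor`, so `.item()` plays the role `.data` plays above); the loop and loss here are hypothetical placeholders for a real training step:

```python
import torch

x = torch.randn(3, requires_grad=True)

# Accumulating the raw loss tensor keeps every iteration's graph alive:
total_with_graph = torch.zeros(())
for _ in range(5):
    loss = (x * 2).sum()            # stand-in for a real loss computation
    total_with_graph = total_with_graph + loss  # graph grows each step

assert total_with_graph.requires_grad  # still tied to all five graphs

# Accumulating a plain Python number drops the history immediately:
total_value = 0.0
for _ in range(5):
    loss = (x * 2).sum()
    total_value += loss.item()      # .item() detaches and converts to float

assert isinstance(total_value, float)
```

With the first pattern, none of the five graphs can be freed until `total_with_graph` itself goes out of scope, which is exactly the per-iteration memory growth described above.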
