Here is what I’m trying to implement:
We calculate the loss based on F(X), as usual. We also define an “adversarial loss”, which is a loss based on F(X + e), where e is defined as dF(X)/dX multiplied by some constant. Both the loss and the adversarial loss are backpropagated as the total loss.
In TensorFlow, this part (getting dF(X)/dX) can be coded like below:
grad, = tf.gradients( loss, X )
grad = tf.stop_gradient(grad)
e = constant * grad
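For comparison, here is a minimal sketch of the same three lines in PyTorch, assuming `torch.autograd.grad` is an acceptable stand-in for `tf.gradients`. The tensor `X`, the toy loss, and the value of `constant` are placeholders for illustration only, not part of the real model.

```python
import torch

# Hypothetical input and loss, only for illustration.
X = torch.randn(4, 3, requires_grad=True)
loss = (X ** 2).sum()
constant = 0.5

# torch.autograd.grad returns a tuple with one gradient per input.
# Without create_graph=True the result carries no graph of its own,
# which plays the role of tf.stop_gradient here.
# retain_graph=True keeps the graph of `loss` so it can still be
# backpropagated later.
grad, = torch.autograd.grad(loss, X, retain_graph=True)
e = constant * grad
```

Unlike `loss.backward()`, this call does not accumulate anything into `X.grad`, so no extra `zero_grad()` is needed afterwards.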
Below is my PyTorch code:
class DocReaderModel(object):
    def __init__(self, embedding=None, state_dict=None):
        self.train_loss = AverageMeter()
        self.embedding = embedding
        self.network = DNetwork(opt, embedding)
        self.optimizer = optim.SGD(parameters)

    def adversarial_loss(self, batch, loss, embedding, y):
        self.optimizer.zero_grad()
        loss.backward(retain_graph=True)
        grad = embedding.grad
        grad.detach_()
        perturb = F.normalize(grad, p=2) * 0.5
        self.optimizer.zero_grad()
        adv_embedding = embedding + perturb
        network_temp = DNetwork(self.opt, adv_embedding)  # This is how to get F(X)
        network_temp.training = False
        network_temp.cuda()
        start, end, _ = network_temp(batch)  # This is how to get F(X)
        del network_temp  # I even deleted this instance.
        return F.cross_entropy(start, y[0]) + F.cross_entropy(end, y[1])

    def update(self, batch):
        self.network.train()
        start, end, pred = self.network(batch)
        loss = F.cross_entropy(start, y[0]) + F.cross_entropy(end, y[1])
        loss_adv = self.adversarial_loss(batch, loss, self.network.lexicon_encoder.embedding.weight, y)
        loss_total = loss + loss_adv
        self.optimizer.zero_grad()
        loss_total.backward()
        self.optimizer.step()
I have a few questions:
- I substituted tf.stop_gradient with grad.detach_(). Is this correct?
- I was getting "RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.", so I added retain_graph=True to the loss.backward call. That specific error went away. However, now I’m getting a memory error after a few epochs (RuntimeError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1525909934016/work/aten/src/THC/generic/THCStorage.cu:58). I suspect I’m unnecessarily retaining the graph.
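For discussion, here is a minimal sketch of one way the whole step could be arranged without constructing a second network. It assumes the perturbation can be added to the embedding matrix inside the forward pass. `ToyNet`, its `perturb` argument, and all sizes are hypothetical stand-ins for DNetwork. The sketch does not demonstrate whether this resolves the out-of-memory error on the real model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class ToyNet(nn.Module):
    """Hypothetical stand-in for DNetwork: embedding -> mean pool -> linear."""
    def __init__(self, vocab=10, dim=8, classes=3):
        super().__init__()
        self.embedding = nn.Embedding(vocab, dim)
        self.out = nn.Linear(dim, classes)

    def forward(self, tokens, perturb=None):
        weight = self.embedding.weight
        if perturb is not None:
            weight = weight + perturb  # X + e, using the same parameters
        emb = F.embedding(tokens, weight)
        return self.out(emb.mean(dim=1))

net = ToyNet()
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
tokens = torch.randint(0, 10, (4, 5))
y = torch.randint(0, 3, (4,))

net.train()
optimizer.zero_grad()
loss = F.cross_entropy(net(tokens), y)

# Gradient of the clean loss w.r.t. the embedding matrix. It is returned
# directly instead of being accumulated into .grad, and it has no graph
# attached. retain_graph=True keeps the clean graph alive until
# loss_total.backward() below, which then frees it.
grad, = torch.autograd.grad(loss, net.embedding.weight, retain_graph=True)
perturb = 0.5 * F.normalize(grad, p=2, dim=1)

loss_adv = F.cross_entropy(net(tokens, perturb), y)
loss_total = loss + loss_adv
loss_total.backward()  # single backward pass over both losses
optimizer.step()
```

The design choice here is that the graph is retained only within one training step: the final `backward()` is called without `retain_graph`, so nothing from this step is meant to carry over into the next one.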
Can someone (@albanD) let me know PyTorch’s best practice on this? Any hint or even a short comment would be highly appreciated.