Derivative of two networks

I have two networks, A and B.
The architecture is like this:

x -> A -> A(x) -> B -> B(A(x))

The input is x; I first compute A(x) and then pass it to B to get B(A(x)).
The loss is computed as a function of both outputs: loss(B(A(x)), A(x)).
The parameters of B are not allowed to receive gradients.

In PyTorch, if I first do (say predict is A(x)):

new_predict = Variable(predict.data, requires_grad=False)

then pass new_predict to B and call loss.backward(), it kind of works.

But if I don't do that, the learning just fails.

I wonder why.

Thanks
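
A minimal sketch of this setup (a toy example: A and B stand in as nn.Linear modules and mse_loss stands in for the actual loss; in current PyTorch, predict.detach() is the idiomatic replacement for Variable(predict.data, requires_grad=False)):

import torch
import torch.nn as nn
import torch.nn.functional as F

A = nn.Linear(10, 10)              # stand-in for network A
B = nn.Linear(10, 10)              # stand-in for network B
for p in B.parameters():           # B's parameters must not receive gradients
    p.requires_grad_(False)

x = torch.randn(4, 10)
predict = A(x)                     # A(x)
new_predict = predict.detach()     # cut the graph: B sees A(x) as a constant
out = B(new_predict)               # B(A(x)); nothing flows back into A here

loss = F.mse_loss(out, predict)    # loss(B(A(x)), A(x))
loss.backward()                    # A still gets gradients via the A(x) term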


I’m not sure it’s working in the way you want it to.
Your architecture looks like a GAN, where only the generator will be trained. If you detach your prediction (with .data), you will lose the gradients for A.
Could you check if you get valid gradients in A?
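
One way to run that check (a sketch, assuming A is an nn.Module and loss.backward() has already been called):

for name, param in A.named_parameters():
    if param.grad is None:
        print(name, 'no gradient')
    else:
        print(name, param.grad.norm().item())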


Yes, you are right, it's a kind of GAN.

Because the loss is a function of both B(A(x)) and A(x), even if I detach the prediction A(x) before passing it to B, I still get gradients through the A(x) term… (the network actually learns something this way).

What confuses me is that if I don't detach, the network just learns nothing…

Could you please suggest how to check the gradients?

I have worked on something similar. Not sure if this is right; @ptrblck can confirm whether it makes sense.

I have three networks: embed, classifier, and adversary. The embed+classifier network takes a gradient step with the adversary fixed, and the adversary takes a gradient step with the embed+classifier fixed.
In my training loop I did something like:

embed.train()
classifier.train()
embedding = embed(input)
output = classifier(embedding)
classifier_loss = classifier_criterion(output, target)

adversary.eval()
with torch.no_grad():
    sensitive_output = adversary(embedding)  # computed without building a graph
adv_loss = adv_criterion(sensitive_output, sensitive_target)

loss = classifier_loss + adv_loss

classifier_optimizer.zero_grad()
embed_optimizer.zero_grad()
loss.backward()
classifier_optimizer.step()
embed_optimizer.step()

Similarly, for the adversary gradient step I computed classifier_loss with embed and classifier set to eval() and wrapped in torch.no_grad().
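
Written out, that adversary step would look something like this (a sketch following the same pattern; adversary_optimizer is a placeholder name, and the min-max sign convention is left out):

embed.eval()
classifier.eval()
with torch.no_grad():
    embedding = embed(input)                 # no graph: embed is fixed
    output = classifier(embedding)           # no graph: classifier is fixed
classifier_loss = classifier_criterion(output, target)

adversary.train()
sensitive_output = adversary(embedding)      # only the adversary builds a graph
adv_loss = adv_criterion(sensitive_output, sensitive_target)

adversary_optimizer.zero_grad()
(classifier_loss + adv_loss).backward()      # gradients reach only the adversary
adversary_optimizer.step()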


Could you explain your loss function a bit?
You are right. Even if you detach A(x) and pass it to B, you can still get gradients from A(x).

@viraat Could you check whether adv_loss is actually used? Just define the loss as

loss = adv_loss

and run your code. I think you will get a RuntimeError stating that your tensor does not require grad and does not have a grad_fn.
Because classifier_loss is valid, you don't see an error at the moment.
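
A minimal repro of that error (toy tensors, not the actual models):

import torch

model = torch.nn.Linear(2, 2)
x = torch.randn(3, 2)

with torch.no_grad():
    out = model(x)      # out has no grad_fn

loss = out.sum()
loss.backward()         # RuntimeError: element 0 of tensors does not require
                        # grad and does not have a grad_fn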


My loss is KL(A(x) || B(A(x))); B is well trained and fixed, and only A needs training… (I think it's called an educational network.) Could it be that the KL is too complex for autograd? :sweat_smile:
Thanks.
My understanding is that to differentiate the KL w.r.t. theta_A (writing the KL as a function of A(x) and B(A(x)) to keep it simple), even though B is not allowed to receive gradients, we should still get a derivative from two parts:

$$\frac{\partial\,\mathrm{KL}}{\partial A(x)}\cdot\frac{\partial A(x)}{\partial \theta_A} \;+\; \frac{\partial\,\mathrm{KL}}{\partial B(A(x))}\cdot\frac{\partial B(A(x))}{\partial A(x)}\cdot\frac{\partial A(x)}{\partial \theta_A}$$

Currently, it seems I have to give up the second term to make it work.
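
For reference, a sketch of that loss with the second term given up by detaching (A, B, and x as in the toy sketch earlier; this assumes both networks output logits, since F.kl_div expects log-probabilities as input and probabilities as target):

import torch.nn.functional as F

p_logits = A(x)                        # A(x)
with torch.no_grad():                  # B is well trained and fixed
    q_logits = B(p_logits)             # B(A(x)) treated as a constant

p = F.softmax(p_logits, dim=-1)        # P = distribution from A(x)
log_q = F.log_softmax(q_logits, dim=-1)

# KL(P || Q): only the d KL / d A(x) path reaches theta_A
loss = F.kl_div(log_q, p, reduction='batchmean')
loss.backward()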

@ptrblck you are right. I get the error below when I call .backward() only on adv_loss:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-57-6c7629d37758> in <module>()
     13 
     14 for e in range(num_epochs):
---> 15     clsTrain_loss, clsTrain_acc, advTrain_loss, advTrain_acc = laftr_epoch(encoder, classifier, adversary, X_train, y_train, a_train, en_opt, cls_opt, adv_opt, cls_criterion, adv_criterion)
     16 
     17     if e % 10 == 0:

<ipython-input-53-1a2ea661e046> in laftr_epoch(encoder, classifier, adversary, X, y_cls, y_adv, opt_en, opt_cls, opt_adv, cls_criterion, adv_criterion, batch_size)
     33 #         cls_en_combinedLoss = cls_loss + adv_loss_fixed
     34         cls_en_combinedLoss = adv_loss_fixed
---> 35         cls_en_combinedLoss.backward()
     36 #         print(cls_en_combinedLoss.grad_fn)
     37         opt_cls.step()

~/anaconda/envs/pytorch/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
     91                 products. Defaults to ``False``.
     92         """
---> 93         torch.autograd.backward(self, gradient, retain_graph, create_graph)
     94 
     95     def register_hook(self, hook):

~/anaconda/envs/pytorch/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     88     Variable._execution_engine.run_backward(
     89         tensors, grad_tensors, retain_graph, create_graph,
---> 90         allow_unreachable=True)  # allow_unreachable flag
     91 
     92 

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Basically, I want my classifier+encoder to take a gradient step while the adversary is fixed, minimising the combined loss. Then the adversary will take a step with the classifier+encoder fixed, maximising the combined loss. What would be the best way to achieve this?

My training loop code is below. Any comments on best practices for doing stuff like this are also greatly appreciated! Thanks.

def laftr_epoch(encoder, classifier, adversary, X, y_cls, y_adv, opt_en, opt_cls, opt_adv, cls_criterion, adv_criterion, batch_size=64):
    cls_en_combinedLosses = []
    cls_en_accs = []
    adv_combinedLosses = []
    adv_accs = []
    
    for beg_i in range(0, X.shape[0], batch_size):
        
        x_batch = X.iloc[beg_i:beg_i + batch_size].values
        y_cls_batch = y_cls[beg_i:beg_i + batch_size]
        y_adv_batch = y_adv[beg_i:beg_i + batch_size]        
        
        x_batch = torch.from_numpy(x_batch).to(device).float()
        y_cls_batch = torch.from_numpy(y_cls_batch).to(device).float()
        y_adv_batch = torch.from_numpy(y_adv_batch).to(device).float()
        
        # fix adversary take gradient step with classifier and encoder
        encoder.train()
        classifier.train()
        z = encoder(x_batch)
        y_hat = classifier(z)

        adversary.eval()
        with torch.no_grad():
            a_fixed = adversary(z)

        opt_cls.zero_grad()
        opt_en.zero_grad()
        
        cls_loss = cls_criterion(y_hat, y_cls_batch)
        adv_loss_fixed = adv_criterion(a_fixed, y_adv_batch, y_cls_batch)
        cls_en_combinedLoss = cls_loss + adv_loss_fixed
        cls_en_combinedLoss.backward()
        opt_cls.step()
        opt_en.step()
        
        # fix encoder and classifier and take gradient step with adversary
        encoder.eval()
        classifier.eval()
        with torch.no_grad():
            z_fixed = encoder(x_batch)
            y_hat_fixed = classifier(z_fixed)
        
        adversary.train()
        a_hat = adversary(z_fixed)
        
        opt_adv.zero_grad()
        
        cls_loss_fixed = cls_criterion(y_hat_fixed, y_cls_batch)
        adv_loss = adv_criterion(a_hat, y_adv_batch, y_cls_batch)
        
        adv_combinedLoss = -(cls_loss_fixed + adv_loss)
        adv_combinedLoss.backward()
        
        opt_adv.step()
        
        cls_en_combinedLosses.append(cls_en_combinedLoss.item())
        adv_combinedLosses.append(adv_combinedLoss.item())
        
        cls_preds = torch.round(y_hat).squeeze(1).detach().cpu().numpy()
        cls_acc = (cls_preds == y_cls_batch.cpu().numpy()).mean()
        cls_en_accs.append(cls_acc)
        
        adv_preds = torch.round(a_hat).squeeze(1).detach().cpu().numpy()
        adv_acc = (adv_preds == y_adv_batch.cpu().numpy()).mean()
        adv_accs.append(adv_acc)

    return np.mean(cls_en_combinedLosses), np.mean(cls_en_accs), np.mean(adv_combinedLosses), np.mean(adv_accs)

Running the same code without torch.no_grad() doesn’t throw an error.

What is the difference between calling .eval(), using with torch.no_grad(), and combining both?

EDIT: From 'model.eval()' vs 'with torch.no_grad()' - #2 by albanD

  • model.eval() will notify all your layers that you are in eval mode; that way, batchnorm or dropout layers will work in eval mode instead of training mode.
  • torch.no_grad() impacts the autograd engine and deactivates it. It will reduce memory usage and speed up computations, but you won't be able to backprop (which you don't want in an eval script).
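
A quick illustration of that difference, with a small dropout model standing in:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))
x = torch.randn(2, 4)

model.eval()
out = model(x)                            # dropout disabled, graph still built
out.sum().backward()                      # works: .eval() alone does not block gradients
print(model[0].weight.grad is not None)   # True

with torch.no_grad():
    out = model(x)                        # no graph is built at all
print(out.requires_grad)                  # False: you cannot backprop through this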

So, when I do a gradient step for the classifier+encoder I want the adversary to be fixed. The combined loss from this step should not affect the adversary’s weights when I do the adversary weight update later.

Would changing the model from train to eval have that effect?

No, it would not. Switching to eval mode would not prevent the model's parameters from being updated by the optimizer. Instead you could freeze them via

for param in adversary.parameters():
    param.requires_grad = False

and set requires_grad back to True when you want to update the adversary model again.
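
A small helper for that toggle (a hypothetical convenience wrapper around the loop above):

def set_requires_grad(model, flag):
    for param in model.parameters():
        param.requires_grad = flag

set_requires_grad(adversary, False)  # before the classifier+encoder step
# ... forward, backward, optimizer steps ...
set_requires_grad(adversary, True)   # before the adversary's own step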


Thanks for letting me know.

Additionally, if I would like batchnorm and dropout layers to work in eval mode, I have to call .eval() as well as set requires_grad = False for the weights of the model? And vice versa, call .train() and set requires_grad = True when I want the weights to be updated?

That's right. And if you don't need the operations tracked by autograd, you can additionally use the torch.no_grad() context manager, since it will reduce memory consumption.

To be clear, I want the combined_loss to be used for performing a gradient step on the classifier+encoder networks.
This is how I currently do it (after taking into account the comments above)

encoder.train()
classifier.train()
z = encoder(x_batch)
y_hat = classifier(z)

adversary.eval()
for param in adversary.parameters():
    param.requires_grad = False
a_fixed = adversary(z)

opt_cls.zero_grad()
opt_en.zero_grad()

cls_loss = cls_criterion(y_hat, y_cls_batch)
adv_loss_fixed = adv_criterion(a_fixed, y_adv_batch, y_cls_batch)
cls_en_combinedLoss = cls_loss + adv_loss_fixed
# want this to work only wrt classifier and encoder
cls_en_combinedLoss.backward()
opt_cls.step()  # classifier step
opt_en.step()   # encoder step

The optimizer steps for the classifier and encoder should be taken with respect to the combined loss, not just cls_loss.
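
The mirrored adversary step would then look something like this (a sketch; note that the encoder and classifier parameters have to be set back to requires_grad=True before their next update):

# fix encoder and classifier, take a gradient step on the adversary
encoder.eval()
classifier.eval()
for param in encoder.parameters():
    param.requires_grad = False
for param in classifier.parameters():
    param.requires_grad = False

adversary.train()
for param in adversary.parameters():
    param.requires_grad = True

z_fixed = encoder(x_batch)                 # constant: encoder is frozen
y_hat_fixed = classifier(z_fixed)          # constant: classifier is frozen
a_hat = adversary(z_fixed)

opt_adv.zero_grad()
cls_loss_fixed = cls_criterion(y_hat_fixed, y_cls_batch)
adv_loss = adv_criterion(a_hat, y_adv_batch, y_cls_batch)

adv_combinedLoss = -(cls_loss_fixed + adv_loss)  # adversary maximises
adv_combinedLoss.backward()
opt_adv.step()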


@Xiaofeng_Wu As far as I understand, you are using B(A(x)) as a "fixed" target, i.e. B doesn't need gradients and shouldn't be trained. In my opinion, detaching should be alright, but I'm not sure if there are some pitfalls I'm missing when using KLDiv.
If you don’t detach A(x) while passing it to B, could you check the gradients in A?

@viraat I think the code looks good. I just re-created a small dummy experiment and I get valid gradients for encoder and classifier, while adversary doesn’t have any gradients.
Could you please double check it by calling backward on both losses separately and checking the gradients of your models?
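
For reference, one way to run that check (a sketch; run it in place of the usual backward/step calls, since it calls backward itself):

def has_grads(model):
    return any(p.grad is not None and p.grad.abs().sum() > 0
               for p in model.parameters())

for loss, tag in [(cls_en_combinedLoss, 'cls_en'), (adv_combinedLoss, 'adv')]:
    for m in (encoder, classifier, adversary):
        m.zero_grad()
    loss.backward()
    print(tag, '-> encoder:', has_grads(encoder),
          'classifier:', has_grads(classifier),
          'adversary:', has_grads(adversary))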


@Xiaofeng_Wu Could you tell me why you set requires_grad=False in your code?

new_predict = Variable(predict.data, requires_grad=False)

Mathematically, it is not correct, because by fixing the gradient you only get part of the gradient for theta_A.
But practically, it is, because B is supposed to be a well-trained network, and B(A(x)) is supposed to perform better than A(x). If you don't detach, the other part of the gradient will try to make B(A(x)) be like A(x), which is exactly what we don't want.
