"Learning to Generalize" - how to implement this peculiar loss function?

Hi guys,
I’m trying to implement the paper “Learning to Generalize: Meta-Learning for Domain Generalization” https://arxiv.org/pdf/1710.03463.pdf. The method itself doesn’t seem hard to understand, but I’m stuck on how to implement it in PyTorch.


Let’s say we have two training sets. The method first calculates the loss and its gradients on a batch from the first training set;
then it calculates the loss on a batch from the second training set, w.r.t. the updated parameters;
and finally it combines the two gradient updates, calculating the second gradient, taken at the updated parameters, w.r.t. the original parameters theta.
As pointed out in the paper, this means calculating the second-order gradient of the loss w.r.t. the network parameters.
The method can also be summarized by the objective function reported in the image.
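In the notation used below, that objective from the paper is:

argmin over theta of  F(theta) + beta * G(theta - alpha * F'(theta))

where F is the loss on the meta-train domains and G the loss on the meta-test domain.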
I don’t know how to implement it. The objective function seems to be the easiest route, but whenever I call .backward() on the loss, the gradients are automatically calculated w.r.t. the network parameters (theta), while I need the gradient of G taken at theta - alpha * F'(theta).
If, on the other hand, I try to implement it by manually performing the update of the network weights,
theta = theta - gamma * (…)
I need to manually calculate the gradient of G() at theta - alpha * F'(theta).
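Writing out the chain rule shows why the second-order term is unavoidable:

d/dtheta [ G(theta - alpha * F'(theta)) ] = (I - alpha * F''(theta)) * G'(theta - alpha * F'(theta))

so backprop has to differentiate through F'(theta) itself.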
What I would like to do, in a sort of pseudocode that does not work as written, would be:

criterion = nn.CrossEntropyLoss()

F = criterion(output1, label1)                  # meta-train loss
F_grads = torch.autograd.grad(F, network.parameters(), create_graph=True)

# output2 must come from a forward pass run with the virtually updated
# weights theta - alpha * F'(theta) (see the sketch below), so that
# grad(G, theta) picks up the second-order term through F_grads
G = criterion(output2, label2)                  # meta-test loss
G_grads = torch.autograd.grad(G, network.parameters())

with torch.no_grad():
    for theta, gF, gG in zip(network.parameters(), F_grads, G_grads):
        theta -= gamma * (gF + beta * gG)
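For reference, here is a minimal runnable sketch of this first approach, assuming a recent PyTorch where torch.func.functional_call is available; the network, the two batches, and the hyper-parameter values below are placeholder stand-ins, not from the paper:

import torch
import torch.nn as nn
from torch.func import functional_call

network = nn.Linear(10, 3)                      # placeholder model
criterion = nn.CrossEntropyLoss()
x1, y1 = torch.randn(8, 10), torch.randint(0, 3, (8,))   # meta-train batch
x2, y2 = torch.randn(8, 10), torch.randint(0, 3, (8,))   # meta-test batch
alpha, beta, gamma = 0.01, 1.0, 0.1             # placeholder hyper-parameters

params = dict(network.named_parameters())

# 1) meta-train loss F(theta); keep the graph so F'(theta) stays differentiable
loss_F = criterion(functional_call(network, params, (x1,)), y1)
grads_F = torch.autograd.grad(loss_F, params.values(), create_graph=True)

# 2) virtual update theta' = theta - alpha * F'(theta); theta' is a
#    differentiable function of theta
updated = {name: p - alpha * g
           for (name, p), g in zip(params.items(), grads_F)}

# 3) meta-test loss G evaluated at the virtually updated weights
loss_G = criterion(functional_call(network, updated, (x2,)), y2)

# 4) gradient of F(theta) + beta * G(theta - alpha * F'(theta)) w.r.t. the
#    original theta; autograd supplies the second-order term automatically
total = loss_F + beta * loss_G
grads = torch.autograd.grad(total, params.values())

with torch.no_grad():
    for p, g in zip(params.values(), grads):
        p -= gamma * g

On older versions the same thing can be done by writing the forward pass functionally by hand (e.g. with torch.nn.functional ops that take the weights as explicit arguments).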

Or, alternatively, working directly on the loss, I would like to do something like:

F = criterion(output1, label1)
F.backward()                        # gradients calculated w.r.t. theta
# save the computed .grad buffers for the later combined update
# ...some shenanigans: manually step the weights to theta - alpha * F'(theta)...
G = criterion(output2, label2)      # forward pass at the updated weights
G.backward()                        # evaluated at theta - alpha * F'(theta)
# put the two gradients together, restore theta, and update
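A .backward()-style version like this can only realistically implement a first-order approximation (in the spirit of first-order MAML), because the graph needed for the F''(theta) term is gone once the weights are stepped manually. A sketch of that variant, reusing the same placeholder setup as above:

import torch
import torch.nn as nn

network = nn.Linear(10, 3)                      # placeholder setup, as above
criterion = nn.CrossEntropyLoss()
x1, y1 = torch.randn(8, 10), torch.randint(0, 3, (8,))
x2, y2 = torch.randn(8, 10), torch.randint(0, 3, (8,))
alpha, beta, gamma = 0.01, 1.0, 0.1

# 1) meta-train gradients F'(theta); no create_graph, since the
#    second-order term is dropped in this approximation
loss_F = criterion(network(x1), y1)
grads_F = torch.autograd.grad(loss_F, network.parameters())

# 2) back up theta, then step to theta - alpha * F'(theta)
backup = [p.detach().clone() for p in network.parameters()]
with torch.no_grad():
    for p, g in zip(network.parameters(), grads_F):
        p -= alpha * g

# 3) meta-test gradients, evaluated at the stepped weights
loss_G = criterion(network(x2), y2)
grads_G = torch.autograd.grad(loss_G, network.parameters())

# 4) restore theta and apply the combined update
with torch.no_grad():
    for p, b, gF, gG in zip(network.parameters(), backup, grads_F, grads_G):
        p.copy_(b)
        p -= gamma * (gF + beta * gG)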

Do you have any suggestions on how to proceed?

Have you solved the problem? I’m facing a similar problem now…

Yes, I have. You can find on the web a small demo made by the author that shows how to properly implement their loss function in PyTorch.

Unfortunately, the author of MLDG removed the code from GitHub.
Does anyone have a copy?

Thx,
D

Why remove the previous repo? Maybe he found his experiments had potential errors?

I emailed him, and he said he wanted to remove some redundant commits and will restore the repo soon.