The network is actually divided into two parts: 1) a memory side that is trained on a self-supervised task, where the task is generated by two learned projections, and 2) the language model.
I’m having problems implementing it: during the backward pass, PyTorch raises an error saying that the weights of (1) were changed. I’m not sure what I should look at as a possible fix:
How can I instruct PyTorch to create two backpropagation graphs? Should I freeze the weights and do several passes?
I’m not sure I understand the issue correctly, but I assume PyTorch fails during the backward pass claiming some parameters were already updated? Depending on the logic you want to use to train your model, you might need to detach the output of the first part so that you no longer backpropagate through it.
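A minimal sketch of what I mean, using stand-in linear layers and placeholder losses (`memory_net`, `lm`, and the loss terms are hypothetical, not your actual modules): the memory side gets its own optimizer step on its self-supervised loss, and the language model is then trained on a detached copy of the memory output.

```python
import torch
import torch.nn as nn

# Hypothetical modules: `memory_net` stands in for part (1),
# `lm` for the language model (2).
memory_net = nn.Linear(16, 16)
lm = nn.Linear(16, 8)

opt_mem = torch.optim.Adam(memory_net.parameters(), lr=1e-3)
opt_lm = torch.optim.Adam(lm.parameters(), lr=1e-3)

x = torch.randn(4, 16)

# Step 1: update the memory side on its own loss
# (a stand-in reconstruction objective here).
mem_out = memory_net(x)
mem_loss = (mem_out - x).pow(2).mean()
opt_mem.zero_grad()
mem_loss.backward()
opt_mem.step()

# Step 2: feed a *detached* memory output into the language model.
# detach() cuts the graph, so lm_loss.backward() will not try to
# backpropagate through memory_net's already-updated parameters.
lm_out = lm(memory_net(x).detach())
lm_loss = lm_out.pow(2).mean()  # stand-in LM loss
opt_lm.zero_grad()
lm_loss.backward()
opt_lm.step()
```

With this pattern each `backward()` only ever touches one set of parameters, so the in-place update error should disappear.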
DCGAN might align with some of what you’re attempting to do and could be worth a look. With a DCGAN you train two networks, a discriminator and a generator, and the example shows how to properly handle the detach in the training pipeline; a condensed sketch of that pattern follows below. Here is a tutorial:
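In the meantime, here is a condensed sketch of the update pattern used in that setting (with stand-in linear layers rather than the real DCGAN architecture): the discriminator is trained on a detached generator output, and the generator then gets its own backward pass through both networks.

```python
import torch
import torch.nn as nn

# Stand-in networks, not the actual DCGAN convolutional architecture.
netG = nn.Linear(8, 16)   # generator
netD = nn.Linear(16, 1)   # discriminator
criterion = nn.BCEWithLogitsLoss()
optD = torch.optim.Adam(netD.parameters(), lr=2e-4)
optG = torch.optim.Adam(netG.parameters(), lr=2e-4)

noise = torch.randn(4, 8)
fake = netG(noise)

# Discriminator update: detach `fake` so this backward pass
# does not reach into the generator's graph.
optD.zero_grad()
d_loss = criterion(netD(fake.detach()), torch.zeros(4, 1))
d_loss.backward()
optD.step()

# Generator update: reuse the same `fake` *without* detach so
# gradients flow back through netG. The second forward through
# netD builds a fresh graph, so its in-between parameter update
# does not conflict with this backward pass.
optG.zero_grad()
g_loss = criterion(netD(fake), torch.ones(4, 1))
g_loss.backward()
optG.step()
```

Your memory-side/language-model split is analogous: decide which loss is allowed to reach which parameters, and detach at the boundary for every loss that should stop there.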