This answer is pretty much all you need!
In the SGD example of that answer, you would only need to replace `model.base` with `your_model_name` and `model.classifier` with `your_loss_name`. If you wrote your loss module properly (with its learnable values registered as `nn.Parameter`s, not stored as plain tensors), it should work.
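For illustration, here is a minimal sketch of that swap, assuming the SGD example in question is the per-parameter options example from the `torch.optim` docs. The model (`nn.Linear`), the loss module `LearnableWeightMSELoss`, its `log_var` parameter, and the learning rates are all hypothetical placeholders, not taken from the original question. It also shows why registration matters: only the `nn.Parameter` attribute is returned by `criterion.parameters()`, while the plain tensor attribute is silently ignored by the optimizer.

```python
import torch
import torch.nn as nn
import torch.optim as optim


class LearnableWeightMSELoss(nn.Module):
    """Hypothetical loss module with one learnable value (illustration only)."""

    def __init__(self):
        super().__init__()
        # Registered as nn.Parameter: returned by .parameters(), so the optimizer sees it.
        self.log_var = nn.Parameter(torch.tensor(0.0))
        # Plain tensor attribute: NOT returned by .parameters(), so it is never optimized.
        self.fixed_weight = torch.tensor(1.0)

    def forward(self, pred, target):
        mse = self.fixed_weight * ((pred - target) ** 2).mean()
        # Uncertainty-style weighting: scale the MSE and penalize a large log-variance.
        return torch.exp(-self.log_var) * mse + self.log_var


torch.manual_seed(0)
model = nn.Linear(4, 1)                 # stands in for `your_model_name`
criterion = LearnableWeightMSELoss()    # stands in for `your_loss_name`

# Same structure as the per-parameter options example: one group per module.
optimizer = optim.SGD([
    {'params': model.parameters()},
    {'params': criterion.parameters(), 'lr': 1e-3},
], lr=1e-2, momentum=0.9)

# One training step updates both the model weights and the loss parameter.
x, y = torch.randn(8, 4), torch.randn(8, 1)
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()
```

Keep the two modules separate when building the groups: if the loss module were also a submodule of the model, `model.parameters()` would include `log_var` too, and PyTorch rejects a parameter that appears in more than one parameter group.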