I have a network class Net(nn.Module) and two different weights w0 and w1 (each one is the weights of all layers concatenated into a single vector). Now I want to optimize the network on the line connecting w0 and w1, which means that the weight will have the form theta * w0 + (1 - theta) * w1. So the parameter I want to optimize is no longer the weight itself, but theta.
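Because the weight depends linearly on theta, dw/dtheta = w0 - w1, so the gradient with respect to theta is just the loss gradient projected onto the direction w0 - w1 (theta = 1 gives w0, theta = 0 gives w1). A minimal sketch on plain tensors, assuming a toy squared-norm loss as a hypothetical stand-in for the network loss, checks this against autograd:

```python
import torch

torch.manual_seed(0)
w0 = torch.randn(10)                           # hypothetical flattened endpoint weights
w1 = torch.randn(10)
theta = torch.tensor(0.3, requires_grad=True)  # the scalar being optimized

w = theta * w0 + (1 - theta) * w1  # a point on the line through w0 and w1
loss = (w ** 2).sum()              # hypothetical stand-in for the network loss
loss.backward()

# Chain rule: dL/dtheta = <dL/dw, dw/dtheta> = <dL/dw, w0 - w1>
grad_w = 2 * w.detach()            # dL/dw for this particular loss
expected = torch.dot(grad_w, w0 - w1)
print(theta.grad.item(), expected.item())
```

So autograd only needs w to be built from theta with ordinary tensor operations; it never needs w itself to be a leaf parameter.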
How can I implement this? In PyTorch, how can I define theta as the trainable parameter and set the weights to the form I want? To be specific, if I create a new class NetOnLine(nn.Module), how should I write its forward(self, X) function?
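One possible approach, sketched here under the assumption of PyTorch 2.0 or later: keep a copy of Net only as an architecture template, store w0 and w1 as buffers, make theta the sole nn.Parameter, and in forward slice the interpolated vector back into per-layer tensors and run the template with torch.func.functional_call. The Net layers, sizes, the theta_init argument, and the toy data are hypothetical; w0 and w1 are assumed to be flattened in net.parameters() order, which is what torch.nn.utils.parameters_to_vector produces. (vector_to_parameters is not used because it copies into the parameters' .data, which would cut the graph back to theta.)

```python
import torch
import torch.nn as nn
from torch.func import functional_call
from torch.nn.utils import parameters_to_vector


class Net(nn.Module):
    """Hypothetical stand-in for the question's Net; any nn.Module should work."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 1)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


class NetOnLine(nn.Module):
    def __init__(self, net, w0, w1, theta_init=0.5):
        super().__init__()
        self.net = net
        # The template network only supplies the architecture; its own weights are frozen.
        for p in self.net.parameters():
            p.requires_grad_(False)
        # Endpoints of the line, flattened in net.parameters() order.
        # Buffers follow .to(device) but are not trainable parameters.
        self.register_buffer("w0", w0.detach().clone())
        self.register_buffer("w1", w1.detach().clone())
        # The only trainable quantity: a scalar position on the line.
        self.theta = nn.Parameter(torch.tensor(float(theta_init)))
        # Names and shapes needed to cut the flat vector back into per-layer tensors.
        self.names = [name for name, _ in self.net.named_parameters()]
        self.shapes = [p.shape for _, p in self.net.named_parameters()]

    def forward(self, X):
        # Point on the line; differentiable with respect to theta.
        w = self.theta * self.w0 + (1 - self.theta) * self.w1
        params, offset = {}, 0
        for name, shape in zip(self.names, self.shapes):
            n = shape.numel()
            params[name] = w[offset:offset + n].view(shape)
            offset += n
        # Run the template net with the interpolated tensors in place of its parameters.
        return functional_call(self.net, params, (X,))


# Usage: two hypothetical endpoint networks and a toy regression batch.
torch.manual_seed(0)
net_a, net_b = Net(), Net()
w0 = parameters_to_vector(net_a.parameters())
w1 = parameters_to_vector(net_b.parameters())
model = NetOnLine(Net(), w0, w1)

opt = torch.optim.SGD([model.theta], lr=0.1)  # optimize theta only
X, y = torch.randn(16, 4), torch.randn(16, 1)
for _ in range(5):
    opt.zero_grad()
    loss = nn.functional.mse_loss(model(X), y)
    loss.backward()
    opt.step()
print(model.theta.item())
```

Note that theta is unconstrained here, so the optimizer can move past the endpoints (theta outside [0, 1]); if you want to stay on the segment between w0 and w1, one option is to parametrize it as torch.sigmoid of an unconstrained scalar.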