How can I insert a branch variable into a model graph in PyTorch?

Hello all.
This has confused me for some time now. I am learning PyTorch, and to improve my understanding of how everything works I’m experimenting with different parts of the framework.
Recently I thought about creating custom operations and layers in PyTorch.
The problem is that all of the examples I have seen so far operate on the input in one way or another.
But what if what I want to do does not interact directly with the input data itself?
Suppose I want to have a branch variable and let the network tune it, in order to decide, for example, whether to add a layer or not.
I can’t understand how I’m supposed to do it!
I know I should create a new parameter, for example:

#and the new parameter!
self.mybranch = nn.Parameter(torch.zeros(1))

and add it to model parameters :

def register_hook(self):
    self.register_parameter( 'mybranch' , self.mybranch )
    print('parameters registered!')

and build my model for example:

self.conv1 = nn.Conv2d(3, 10, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
self.conv1_relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(10, 15, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
self.conv2_relu = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(15, 15, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
self.conv3_relu = nn.ReLU(inplace=True)
self.linear1 = nn.Linear(20, 10)

and then in forward I’d write :

def forward(self, x):
    out = self.conv1(x)
    out = self.conv1_relu(out)
    out = self.conv2(out)
    out = self.conv2_relu(out)
    if self.mybranch.item() < 1:
        out = self.conv3(out)
        out = self.conv3_relu(out)
    out = self.linear1(out)
    return out

Before starting the training loop I register the parameter in my model, and then inside the training loop I do:

output = model(input_var)
loss = criterion(output, target_var) 


But this doesn’t work, and I believe it’s because mybranch is not in the graph itself. Although I added my parameter to the model using my register_parameter hook, there is no interaction between my parameter and the input, so I simply can’t get it into the graph and thus make the loss affect it.

What should I do in such situations?


First of all, you do not need to register the parameter explicitly. This is done for you automatically when you save an nn.Parameter as an attribute of your module (as in self.mybranch = nn.Parameter(...)). This means that you can remove your register_hook function completely and it won’t change anything.
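A minimal sketch of that automatic registration. The class name TinyNet and its single conv layer are hypothetical and only there for illustration; the point is that mybranch shows up in named_parameters() without any call to register_parameter:

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=3, padding=1)
        # Assigning an nn.Parameter to an attribute registers it automatically.
        self.mybranch = nn.Parameter(torch.zeros(1))


model = TinyNet()
param_names = [name for name, _ in model.named_parameters()]
print(param_names)  # the list includes 'mybranch'
```

Because the parameter is in model.parameters(), any optimizer built from model.parameters() will also receive it.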

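The problem here is that the operation you do with self.mybranch is not differentiable. Calling .item() gives a plain Python number, so the comparison happens outside of autograd; and even viewed as a function of the parameter, the branch is a step: its gradient is 0 almost everywhere and undefined (infinite) at the point where the value equals 1.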
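To see this concretely, here is a small sketch with a hard if on .item(). The names weight and x are hypothetical stand-ins for a layer and its input; after backward(), the weight has a gradient but the branch parameter has none:

```python
import torch
import torch.nn as nn

mybranch = nn.Parameter(torch.zeros(1))
weight = nn.Parameter(torch.ones(1))  # hypothetical stand-in for a layer's weights

x = torch.tensor([2.0])
out = weight * x
# .item() returns a plain Python float, so this `if` is ordinary Python
# control flow that autograd does not record.
if mybranch.item() < 1:
    out = out * 3.0

out.sum().backward()
print(weight.grad)    # the weight receives a gradient
print(mybranch.grad)  # None: mybranch never entered the graph
```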
What you can do instead is use a “smoother” version of this branching:

def forward(self, x):
    out = self.conv1(x)
    out = self.conv1_relu(out)
    out = self.conv2(out)
    out = self.conv2_relu(out)
    # Always compute the branch
    branch = self.conv3(out)
    branch = self.conv3_relu(branch)
    # Smooth choice between out and branch:
    # squash the choice variable into the range (0, 1)
    branch_choice = torch.sigmoid(self.mybranch)
    out = branch_choice * branch + (1 - branch_choice) * out
    out = self.linear1(out)
    return out
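The same idea as a self-contained sketch that can be run directly. SoftBranchNet, its reduced layer stack, and the random 8x8 input are assumptions made for brevity, not the original model; only the sigmoid-weighted mix is taken from the forward above:

```python
import torch
import torch.nn as nn


class SoftBranchNet(nn.Module):  # hypothetical, trimmed-down model
    def __init__(self):
        super().__init__()
        self.conv2 = nn.Conv2d(3, 15, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(15, 15, kernel_size=3, padding=1)
        self.mybranch = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        out = torch.relu(self.conv2(x))
        # Always compute the optional branch.
        branch = torch.relu(self.conv3(out))
        # gate lies in (0, 1); it starts at 0.5 because mybranch starts at 0.
        gate = torch.sigmoid(self.mybranch)
        return gate * branch + (1 - gate) * out


torch.manual_seed(0)
model = SoftBranchNet()
x = torch.randn(2, 3, 8, 8)  # assumed dummy batch
loss = model(x).mean()       # placeholder loss, just to get a scalar
loss.backward()
print(model.mybranch.grad)   # a tensor of shape (1,), no longer None
```

Note that both paths must produce tensors of the same shape for the weighted sum to work, which holds here because conv3 maps 15 channels to 15 channels with padding that preserves the spatial size.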

Thanks a lot. I see it now :)
By the way, the smoothing part is not needed, right? Or is it? If I put that variable in the last formula,
it will be differentiable, and we should be able to apply gradients to it, right?
Oh, I got it :) No need for an explanation here. Thanks a gazillion times :)

By the way, do I still need to add it to the loss in the training loop myself, or will it get tuned automatically?

Update 2:
Got my second answer as well. There is no need to manually add the parameter values to the loss in the training loop; everything is done automatically :)
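For anyone finding this later, here is a rough sketch of what "automatically" means in practice. The toy model GatedLinear, the random data, and the SGD settings are all made up for illustration; the only point is that an optimizer built from model.parameters() updates mybranch with no extra loss term:

```python
import torch
import torch.nn as nn


class GatedLinear(nn.Module):  # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.base = nn.Linear(4, 4)
        self.extra = nn.Linear(4, 4)
        self.mybranch = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        out = self.base(x)
        branch = torch.relu(self.extra(out))
        gate = torch.sigmoid(self.mybranch)
        return gate * branch + (1 - gate) * out


torch.manual_seed(0)
model = GatedLinear()
criterion = nn.MSELoss()
# model.parameters() already contains mybranch, so the optimizer sees it.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs = torch.randn(16, 4)   # assumed dummy data
targets = torch.randn(16, 4)
before = model.mybranch.detach().clone()

for _ in range(5):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()

after = model.mybranch.detach().clone()
print(before, after)  # the value moves away from its initial 0
```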