Hello all.
This has confused me for some time now. I am learning PyTorch, and to improve my understanding of how everything works, I'm experimenting with different sections of the framework.
Recently I thought about creating custom operations and layers in PyTorch.
The problem is, all of the examples I have seen so far involve the input in one way or another.
But what if what I want to do has no direct interaction with the input data itself?
Suppose I want to have a branch variable and let the network tune it, in order to, for example, decide whether to apply a layer or not.
I can't understand how I'm supposed to do it!
I know I should create a new parameter, for example:
```python
# and the new parameter!
self.mybranch = nn.Parameter(torch.zeros(1))
```
and add it to the model's parameters:
```python
def register_hook(self):
    self.register_parameter('mybranch', self.mybranch)
    print('parameters registered!')
```
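(As a side note, while testing this I noticed that assigning an `nn.Parameter` as a module attribute already registers it automatically, so the explicit `register_parameter` call may be redundant. A quick check — the module name `Tiny` here is just my own throwaway example:)

```python
import torch
import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning an nn.Parameter attribute registers it automatically
        self.mybranch = nn.Parameter(torch.zeros(1))

names = [n for n, _ in Tiny().named_parameters()]
print(names)  # ['mybranch']
```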
and build my model, for example:
```python
self.conv1 = nn.Conv2d(3, 10, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
self.conv1_relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(10, 15, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
self.conv2_relu = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(15, 15, kernel_size=[3, 3], stride=(1, 1), padding=(1, 1))
self.conv3_relu = nn.ReLU(inplace=True)
self.linear1 = nn.Linear(20, 10)
```
and then in `forward` I'd write:
```python
def forward(self, x):
    out = self.conv1(x)
    out = self.conv1_relu(out)
    out = self.conv2(out)
    out = self.conv2_relu(out)
    if self.mybranch.item() < 1:
        out = self.conv3(out)
        out = self.conv3_relu(out)
    out = self.linear1(out)
    return out
```
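One thing I noticed while debugging: the `.item()` call in the branch condition returns a plain Python number, which carries no autograd information at all. A minimal check:

```python
import torch
import torch.nn as nn

p = nn.Parameter(torch.zeros(1))
v = p.item()  # plain Python float, detached from the autograd graph
print(type(v), p.requires_grad)  # <class 'float'> True
```

So even though `p` itself requires grad, any decision made from `v` is invisible to `backward()`.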
Before starting the training loop, I register the parameter in my model:
```python
model.register_hook()
```
and then inside my training loop I'd do:
```python
...
output = model(input_var)
loss = criterion(output, target_var)
loss.add(model.mybranch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```
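Putting the pieces together, here is a minimal self-contained version of what I tried. To make it actually run I used a global average pool plus `nn.Linear(15, 10)` as the head and random stand-in data — those parts are placeholders, not my real setup. After `backward()`, `mybranch.grad` is still `None`, which shows it never entered the graph:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 10, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(10, 15, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(15, 15, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        # placeholder head so the shapes line up: global pool + linear
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.linear1 = nn.Linear(15, 10)
        self.mybranch = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        if self.mybranch.item() < 1:  # .item() detaches the decision from autograd
            out = self.relu(self.conv3(out))
        out = self.pool(out).flatten(1)
        return self.linear1(out)

model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(4, 3, 8, 8)           # random stand-in data
target = torch.randint(0, 10, (4,))

output = model(x)
loss = criterion(output, target)
loss.add(model.mybranch)              # out-of-place: the result is discarded!
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(model.mybranch.grad)  # None -- mybranch never entered the graph
```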
but this doesn't work, and I believe it's because `mybranch` is not in the graph itself. Although I added my parameter to the model using `register_parameter`, since there is no interaction between my variable/parameter and the input, I simply can't get it into the graph, and thus the loss can't affect it.
What should I do in such situations?