Control the behavior of a neural network with a condition


I am trying to develop an architecture in which some computation is only triggered by an external signal (for example, when the loss barely changes between two epochs).

However, I am not sure how to implement this, for two reasons:

  • If I write a conditional in the model definition like the following, I am not sure whether the conditional will actually be executed:
class NeuralNet(nn.Module):
  # ...
  def forward(self, x):
    if <conditional>:
      # do something
    x = ...
    return x
  • I cannot use a hook function for this, since my operation affects both the forward and the backward pass, and there is a known bug in register_backward_hook.

Any input is appreciated!


Yes, conditionals in the forward function are executed, since PyTorch builds the computation graph dynamically (at runtime) from the operations that actually run in the forward function. For example, here is the forward function of a BasicBlock in torchvision’s implementation of ResNet:

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
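
If you want to convince yourself, here is a minimal sketch. GatedNet and its use_extra attribute are made-up names for illustration, not part of PyTorch. The `if` is ordinary Python, so it is re-evaluated on every call, and autograd only records the branch that actually ran. One way to check this is to look at which parameters receive a gradient:

```python
import torch
import torch.nn as nn


class GatedNet(nn.Module):
    """Toy model (hypothetical) with an extra layer guarded by a plain Python `if`."""

    def __init__(self):
        super().__init__()
        self.base = nn.Linear(4, 4)
        self.extra = nn.Linear(4, 4)
        self.use_extra = False  # hypothetical flag, set from outside the model

    def forward(self, x):
        out = self.base(x)
        if self.use_extra:  # evaluated on every forward call
            out = self.extra(out)
        return out.sum()


torch.manual_seed(0)
model = GatedNet()
x = torch.randn(2, 4)

# Branch skipped: `extra` never enters the graph, so it gets no gradient.
model(x).backward()
grad_when_off = model.extra.weight.grad

# Branch taken: `extra` is now part of the graph and receives a gradient.
model.zero_grad()
model.use_extra = True
model(x).backward()
grad_when_on = model.extra.weight.grad

print(grad_when_off is None, grad_when_on is not None)
```

So the backward pass follows whatever the forward pass did on that particular call, with no extra work needed on your side.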

Thank you for your reply. I think this is a major advantage of PyTorch over TF.

Then here comes the second question.

In the implementation of ResNet, the downsample variable is passed to BasicBlock during initialization (what I call an “internal” signal). In my use case, however, the conditional is activated by an “external” signal such as the number of epochs, so how do I pass this external signal to the forward pass?

I have two ideas:

  • Declare a global variable using global VAR and let the forward function access it.
  • Pass VAR directly through forward, which makes the signature look like
def forward(self, x, VAR):
    if VAR:
        # do something

So which one is correct and preferred?

I think both are correct in the sense that they achieve the same result, but I would prefer the second approach. I do not like global variables, and they are generally considered bad practice. Maybe someone more knowledgeable will chime in with a different answer; that’s just my two cents.
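
For what it’s worth, here is a rough sketch of the second approach. The layer sizes, the use_extra argument, and the “switch on from epoch 3” rule are all assumptions made up for illustration; substitute whatever signal you actually have (for example, a check on the change in loss). The training loop owns the signal and simply hands it to forward as an extra argument:

```python
import torch
import torch.nn as nn


class NeuralNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)
        self.extra = nn.Linear(2, 2)  # hypothetical conditional computation

    def forward(self, x, use_extra=False):
        x = self.fc(x)
        if use_extra:  # external signal passed in by the caller
            x = self.extra(x)
        return x


torch.manual_seed(0)
model = NeuralNet()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
inputs, targets = torch.randn(8, 4), torch.randn(8, 2)

history = []
for epoch in range(5):
    # The loop decides the signal; here it is an assumed epoch-based rule.
    use_extra = epoch >= 3
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs, use_extra), targets)
    loss.backward()
    optimizer.step()
    history.append(use_extra)

print(history)
```

Calling model(inputs, use_extra) works because nn.Module forwards any extra positional or keyword arguments to forward. The default value keeps model(inputs) working for code that does not care about the signal.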