How to block gradient of ResNet shortcut?

I am trying to block the gradient of the ResNet shortcut, i.e. residual in the code below when self.downsample is None (which means residual = x). If self.downsample is not None, I can register a backward hook function block_grad(self, grad_input, grad_output) on the self.downsample layer and set grad_input[0] to 0. However, when self.downsample is None, there is no layer to register the backward hook on. So my question is: how can I block the gradient in this case?

```python
import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut: the identity by default, a projection if downsample is set
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
```
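One possible way to handle the identity case, sketched under assumptions: give the shortcut its own autograd node with x.view_as(x) and register a tensor hook on that node which returns zeros. Registering the hook directly on x would not work, because x also feeds the convolution path and a hook on x sees the summed gradient of both paths. IdentityBlock and the block_shortcut flag are hypothetical names; the block is a trimmed-down stand-in for BasicBlock with downsample=None, not the original class.

```python
import torch
import torch.nn as nn


class IdentityBlock(nn.Module):
    """Hypothetical stand-in for BasicBlock with downsample=None."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.block_shortcut = False  # hypothetical switch, off by default

    def forward(self, x):
        # view_as gives the shortcut its own autograd node, so a hook on it
        # sees only the gradient of the identity path, not the total grad of x.
        residual = x.view_as(x)
        if self.block_shortcut and residual.requires_grad:
            residual.register_hook(lambda grad: torch.zeros_like(grad))
        out = self.conv(x) + residual
        return torch.relu(out)


torch.manual_seed(0)
block = IdentityBlock(4)
x = torch.randn(1, 4, 6, 6, requires_grad=True)

block.block_shortcut = True  # turn on only for the analysis pass
block(x).sum().backward()
g_hook = x.grad.clone()

# Reference: same forward value, shortcut detached explicitly
x.grad = None
torch.relu(block.conv(x) + x.detach()).sum().backward()
print(torch.allclose(g_hook, x.grad))  # True
```

The forward value is unchanged by the hook; only the gradient arriving at x through the shortcut is replaced by zeros, so the flag can stay off during training.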

Why can't you just .detach() the variable you don't want the gradient to flow through?

I don't think I explained the problem clearly. What I want to do is similar to guided backpropagation: I want to run a backward pass on a trained model, and during that backward pass I would like to block some of the skip paths. What I don't know is how to block the identity residual.

If you don't want the gradient to flow back through residual, then instead of out += residual you can do out += residual.detach(). I'm not sure I understand your needs correctly, though…
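A toy illustration with scalar tensors (not the actual block) of what .detach() does inside a sum: the detached term still contributes to the forward value, but no gradient flows back through it.

```python
import torch

a = torch.tensor(3.0, requires_grad=True)
out_a = 2 * a + a            # both terms contribute: d(out)/da = 2 + 1
out_a.backward()
print(a.grad)                # tensor(3.)

b = torch.tensor(3.0, requires_grad=True)
out_b = 2 * b + b.detach()   # detached "shortcut" only adds to the value
out_b.backward()
print(b.grad)                # tensor(2.)
print(out_b.item())          # 9.0
```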

@SimonW If I do out += residual.detach(), the gradient is blocked during training as well. What I want is to block the gradient at test time only, like this:

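```python
import torch.nn as nn

pred = ResNet(input_variable)  # ResNet: the trained model instance
ce = nn.CrossEntropyLoss()
loss = ce(pred, GT_label)
loss.backward()  # the shortcut gradient should be blocked in this pass
```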

During loss.backward(), I would like to block the gradient flowing through residual. When self.downsample is not None, I can register a backward hook function to change the gradient. However, when self.downsample is None, I cannot do that. How can I still block the gradient through residual in that case?
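For the downsample case described above, a minimal sketch using register_full_backward_hook (the older register_backward_hook is deprecated). The hook returns a zeroed grad_input, and the returned handle lets you remove the hook after the analysis pass so training is unaffected. The layer shapes and the helper name zero_grad_input are assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def zero_grad_input(module, grad_input, grad_output):
    # Replace the gradient w.r.t. the module's input with zeros,
    # so nothing flows back into x through this module.
    return tuple(None if g is None else torch.zeros_like(g) for g in grad_input)


# Hypothetical shapes: both paths map (2, 4, 8, 8) -> (2, 8, 4, 4)
downsample = nn.Sequential(nn.Conv2d(4, 8, 1, stride=2, bias=False),
                           nn.BatchNorm2d(8)).eval()
main = nn.Conv2d(4, 8, 3, stride=2, padding=1, bias=False)

x = torch.randn(2, 4, 8, 8, requires_grad=True)

handle = downsample.register_full_backward_hook(zero_grad_input)
(main(x) + downsample(x)).sum().backward()
grad_blocked = x.grad.clone()
handle.remove()  # restore normal behaviour for later passes

# Gradient through the main path alone, for comparison
x.grad = None
main(x).sum().backward()
grad_main_only = x.grad.clone()

print(torch.allclose(grad_blocked, grad_main_only))  # True
```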

Then just do different things based on self.training. You can set that with module.eval() or module.train().
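A sketch of that suggestion on a toy block (ShortcutBlock and input_grad are hypothetical names): the shortcut is detached only when self.training is False, so calling eval() before the analysis pass blocks it, while train() restores normal behaviour. Note that eval() also switches BatchNorm and Dropout to inference mode, which is usually what you want for a trained model.

```python
import torch
import torch.nn as nn


class ShortcutBlock(nn.Module):
    """Hypothetical toy block: out = relu(conv(x) + x)."""

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1, bias=False)

    def forward(self, x):
        residual = x
        if not self.training:
            # eval mode: keep the shortcut's value but stop its gradient
            residual = residual.detach()
        return torch.relu(self.conv(x) + residual)


def input_grad(block, x):
    x = x.detach().requires_grad_(True)
    block(x).sum().backward()
    return x.grad


torch.manual_seed(0)
block = ShortcutBlock(4)
x = torch.randn(1, 4, 6, 6)

g_train = input_grad(block.train(), x)
g_eval = input_grad(block.eval(), x)

# Gradient through the conv path alone, for comparison
xr = x.clone().requires_grad_(True)
torch.relu(block.conv(xr) + x).sum().backward()
print(torch.allclose(g_eval, xr.grad))  # True
print(torch.allclose(g_train, g_eval))  # False: train mode includes the shortcut
```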

@SimonW Thanks for your help! Even though it is not exactly what I asked, you did give me a clue for solving my problem. Thanks :)
