Query regarding register_backward_hook

I’m working on CNN saliency maps.

import numpy as np
import matplotlib.pyplot as plt
import torchvision

def saliency_map_general(model, input, label):
    if isinstance(model, torchvision.models.Inception3):
        input = preprocess_inception(input)
    else:
        input = preprocess(input)
    input.requires_grad_()  # input.grad is only populated if input requires grad

    output = model(input)
    model.zero_grad()

    output[0][label].backward()

    grads = input.grad.data.clamp(min=0)  # keep only positive gradients
    grads.squeeze_()        # drop the batch dimension: CHW
    grads.transpose_(0, 1)  # CHW -> HCW
    grads.transpose_(1, 2)  # HCW -> HWC
    grads = np.amax(grads.cpu().numpy(), axis=2)  # max over channels

    true_image = input.data.squeeze()        # CHW
    true_image = true_image.transpose(0, 1)  # -> HCW
    true_image = true_image.transpose(1, 2)  # -> HWC
    true_image = deprocess(true_image)

    fig = plt.figure()
    plt.rcParams["figure.figsize"] = (20, 20)

    a = fig.add_subplot(1, 2, 1)
    plt.imshow(true_image)
    plt.title('Original Image')
    plt.axis('off')

    a = fig.add_subplot(1, 2, 2)
    plt.imshow(grads, cmap='hot')
    plt.axis('off')
    plt.title('Saliency Map')

    return grads

Here, I’m taking the derivative of the output score for the true class w.r.t. the input pixels.
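
In other words, something like this toy example (a made-up linear “classifier”, purely to illustrate the mechanism):

import torch
import torch.nn as nn

x = torch.randn(1, 3, 224, 224, requires_grad=True)  # input must require grad
model = nn.Linear(3 * 224 * 224, 10)                 # stand-in for a real CNN

score = model(x.view(1, -1))[0, 3]  # score of a hypothetical class 3
score.backward()                    # fills x.grad with d(score)/d(x)
print(x.grad.shape)                 # torch.Size([1, 3, 224, 224])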

Now, I want to use ‘guided backpropagation’, i.e. taking max(grad, 0) at each layer before passing the gradient on to the previous layer during the backward pass. The documentation of register_backward_hook is not very detailed.

I want something like

def hookfunc(module, gi, go):
    # intended: pass on only the positive part of the gradient
    return gi.clamp(min=0)  # pseudocode: gi is really a tuple, see below

h = model.register_backward_hook(hookfunc)

I tried this, but the hook callback never runs. I think that is because I’m calling .backward() on an element of the output Variable rather than on the model itself. I’m not sure about this, and the documentation doesn’t make it clear.

To clarify, I want the hook callback to run for every layer during the backward pass. Any help on how to do this would be appreciated.

Where do you attach your hook to your layer? (I’m also working on saliency maps :wink:)

I don’t think the hook will work as expected (see this recent thread for a discussion of the values the hook receives: Exact meaning of grad_input and grad_output).
If you write the model yourself, you could just use hidden.register_hook(lambda grad: grad.clamp(min=0)) (similar to the gradient clipping discussed here) on your activations between the layers.
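
For example, a minimal sketch with a made-up two-layer model (only the register_hook call matters here):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 2)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        # clamp the gradient flowing back through this activation
        hidden.register_hook(lambda grad: grad.clamp(min=0))
        return self.fc2(hidden)

x = torch.randn(1, 10, requires_grad=True)
TinyNet()(x)[0, 0].backward()  # x.grad is now computed from clamped gradients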

Best regards

Thomas


First of all, thanks for the awesome explanation of gradInput and gradOutput. I was trying to figure out what they were, but couldn’t find it anywhere in the docs.

I’ve found a workaround for this. It’s not an elegant solution, but it works.

def saliency_map_general(model, input, label):
    if isinstance(model, torchvision.models.Inception3):
        input = preprocess_inception(input)
    else:
        input = preprocess(input)
    input.requires_grad_()  # so that input.grad gets populated

    handles = [None] * len(list(model.modules()))

    def hookfunc(module, gradInput, gradOutput):
        print('hook callback is running')
        # do something here

    # register the hook on every module so it fires at each layer
    for j, m in enumerate(model.modules()):
        handles[j] = m.register_backward_hook(hookfunc)

    output = model(input)
    model.zero_grad()

    output[0][label].backward()

    # clean up: remove all hooks again
    for handle in handles:
        handle.remove()

    grads = input.grad.data.clamp(min=0)
    grads.squeeze_()
    grads.transpose_(0, 1)
    grads.transpose_(1, 2)
    grads = np.amax(grads.cpu().numpy(), axis=2)

    return grads

Can someone guide me on how to clamp the gradients to 0? (where I wrote “# do something here” in the code).

gradInput and gradOutput are both tuples, so I can’t apply tensor operations like clamp to them directly.

I also can’t convert gradInput to type “list”, because that takes way too much time (I grew impatient and restarted the Jupyter notebook after a few minutes).

Does return [(None if g is None else g.clamp(min=0)) for g in gradInput] look vaguely right? (I didn’t try it, but I would hope the gradients are either None or a Variable.)

Best regards

Thomas


I was getting an error:
expected tuple, but hook returned ‘list’

This works without error, though:

def hookfunc(module, gradInput, gradOutput):
    return tuple([(None if g is None else g.clamp(min=0)) for g in gradInput])
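
(It seems a backward hook has to return either None or a tuple with one entry per element of gradInput, which is why wrapping the list in tuple() fixes the error.)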

Thanks for the help!

Finally got it right! Thanks @tom


Looks amazing! Was this the paper by the VGG group? Do you plan on doing object localisation as well?

Guided backprop was introduced in this paper: https://arxiv.org/abs/1412.6806

Yes, I’m writing a survey paper on weakly supervised localization techniques, so this will be a part of it.


Do you have a github repo for this? Would be nice to have that to look into. Thanks!

Here it is. It’s not complete yet; there’s still a lot of work to do.
