is_leaf is True and requires_grad is True, but grad is None

Hi,

I want to visualize convolutional features, but when I call backward(),
the input variable’s grad is still None.

I have checked the input variable, and is_leaf is True.
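
Concretely, these are the prints I added inside visualize(), right after loss.backward():

print(img_var.is_leaf)        # True
print(img_var.requires_grad)  # True
print(img_var.grad)           # None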

import numpy as np
import torch
from PIL import Image
from torch.autograd import Variable

# transform_train (my preprocessing pipeline) is defined elsewhere.

class SaveFeatures():
    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)
    def hook_fn(self, module, input, output):
        self.features = output.clone().detach().requires_grad_(True)
    def close(self):
        self.hook.remove()

class FilterVisualizer():
    def __init__(self, model):
        self.model = model
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

    def visualize(self, layer, filter, lr=0.1, opt_steps=100):
        img = np.uint8(np.random.uniform(150, 180, (48, 48, 3)))  # generate a random start image, shape (48, 48, 3)
        img = transform_train(Image.fromarray(img)) # 3, 44, 44
        img = torch.unsqueeze(img, 0) # 1, 3, 44, 44
        activations = SaveFeatures(list(self.model.children())[0][layer])  # register hook
        img_var = Variable(img, requires_grad=True)  # convert image to Variable that requires grad
        optimizer = torch.optim.Adam([img_var], lr=lr, weight_decay=1e-6)
        for n in range(opt_steps):  # optimize pixel values for opt_steps times
            optimizer.zero_grad()
            self.model(img_var)
            loss = -activations.features[0, filter].mean()  # features.shape = (1, 64, 44, 44); features[0, filter] is (44, 44)
            loss.backward()
            optimizer.step()

Hello,
There are two points in your snippet that don’t make sense to me.

First, I understand that you want to extract features from intermediate layers, but I think the first index [0] already represents the layer index, so what is the [layer] for?

Second, the optimizer will only update the input, which seems odd. Is there a specific reason you only want to “update the input”?

Also, I have reproduced your issue with a dummy ConvNet, and I think the problem arises in these lines:

    def hook_fn(self, module, input, output):
        self.features = output.clone().detach().requires_grad_(True)

You should remove the .detach(): it cuts the cloned output out of the autograd graph, so backward() stops there and never reaches the input. With it removed, input.grad and model.module.weight.grad are no longer None.
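
For reference, here is my repro reduced to a minimal sketch (the tiny ConvNet and the sizes are made up, not your model):

import torch
import torch.nn as nn

# Tiny stand-in ConvNet, just to reproduce the issue.
model = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1))
model.eval()
for p in model.parameters():
    p.requires_grad = False  # frozen, as in your FilterVisualizer

features = {}

def hook_fn(module, input, output):
    # Buggy: output.clone().detach().requires_grad_(True) starts a new
    # graph here, so backward() cannot reach the input.
    # Fixed: keep the output attached to the original graph.
    features['out'] = output

hook = model[0].register_forward_hook(hook_fn)

img_var = torch.randn(1, 3, 8, 8, requires_grad=True)
model(img_var)
loss = -features['out'][0, 0].mean()
loss.backward()
print(img_var.grad is None)  # False with the fix, True with .detach()
hook.remove()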

Thanks for your reply,

First,

print(list(self.model.children()))
[Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
   .....
  ), Linear(in_features=512, out_features=7, bias=True)]

There are two items in the list, and I need to choose the Sequential one: [0] selects the Sequential, and [layer] then indexes a layer inside it.
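
In other words, with a stand-in model of the same top-level shape (a sketch, not my real network):

import torch.nn as nn

# Stand-in with the same top-level structure as the print above.
model = nn.Sequential(
    nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.ReLU()),
    nn.Linear(512, 7),
)

children = list(model.children())  # [Sequential(...), Linear(...)]
target_layer = children[0][0]      # [0] picks the Sequential, [0] picks a layer inside it
print(target_layer)                # Conv2d(3, 64, kernel_size=(3, 3), ...)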

Second, the optimizer only updates the input because I don’t want to train the model:
I want to visualize convolutional features by optimizing an input image that activates the selected filter.
reference

I removed the .detach() and it works. Thanks a lot!
