Some problems with hooks

Hi everyone,
I'm trying to visualize the feature maps in my network. The code is as follows:

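import torch
import torch.nn as nn

total_grad_input = []  # gains one entry per hooked ReLU on every backward pass

def hook_fn_backward(module, grad_input, grad_output):
    # print(module)
    # print('grad_output', grad_output)
    # print('grad_input', grad_input)
    total_grad_input.append(grad_input)

PD2SE = PD2SEModel()  # model class defined elsewhere (not shown)
# PD2SE = torch.load("./models/55.pth")
modules = PD2SE.named_children()
test = list(modules)
# test[0][1] is the first child module; iterating over it assumes it is a container such as nn.Sequential
for module in test[0][1]:
    if isinstance(module, nn.ReLU):
        module.register_backward_hook(hook_fn_backward)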

By doing this, it does output the feature maps I want (I think). But I don't know how to remove the hook afterwards, and it causes a CUDA out-of-memory error. I've read some documentation saying I should remove the hook, so my question is: where should I remove it? Alternatively, I think I could write a function that saves the feature maps to local storage and then remove the hook, since the code uses it again later. Looking forward to your reply.
@albanD @ptrblck
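On the question of where to remove the hook: each register_*_hook call returns a handle, and calling its remove() method unregisters that hook. The sketch below is a minimal illustration under stated assumptions: PyTorch 1.8 or later (for register_full_backward_hook, the non-deprecated replacement for register_backward_hook), and a small hypothetical nn.Sequential standing in for PD2SEModel.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for PD2SEModel: a small conv/ReLU stack.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
    nn.ReLU(),
)

grads = []  # gradient w.r.t. each ReLU's output, kept as detached CPU copies

def hook_fn_backward(module, grad_input, grad_output):
    # grad_output is a tuple; detach and copy to CPU so no GPU memory is pinned
    grads.append(grad_output[0].detach().cpu())

# Keep the handle returned by each registration so the hook can be removed later.
handles = [
    m.register_full_backward_hook(hook_fn_backward)
    for m in model.modules()
    if isinstance(m, nn.ReLU)
]

x = torch.randn(1, 3, 16, 16, requires_grad=True)
model(x).sum().backward()  # fires the hook once per ReLU

# Remove the hooks as soon as the data you need has been collected.
for h in handles:
    h.remove()
handles.clear()
```

After the handles are removed, further backward passes no longer append to grads.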


Just as a note, check the docs for register_backward_hook, in particular the warning that it might not return what you want in some cases at the moment.
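Since the question is about feature maps (activations) rather than gradients, one option that avoids the backward-hook caveat altogether is a forward hook, a different technique from the backward hook used in the question. A minimal sketch, assuming a small hypothetical model in place of PD2SEModel; storing each activation under its layer name in a dict means a new forward pass overwrites the entry instead of growing a list.

```python
import torch
import torch.nn as nn

# Hypothetical small model standing in for PD2SEModel.
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(4, 8, kernel_size=3, padding=1),
    nn.ReLU(),
)

feature_maps = {}  # layer name -> activation (detached, on CPU)

def save_activation(name):
    # Returns a forward hook that stores the module's output under `name`.
    def hook(module, inputs, output):
        feature_maps[name] = output.detach().cpu()
    return hook

handles = [
    m.register_forward_hook(save_activation(name))
    for name, m in model.named_modules()
    if isinstance(m, nn.ReLU)
]

with torch.no_grad():  # no gradients are needed just to look at activations
    model(torch.randn(1, 3, 32, 32))

for h in handles:
    h.remove()
```

For an nn.Sequential, named_modules() names the children "0", "1", ..., so the two ReLU outputs end up under the keys "1" and "4".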

The memory problem here is that you do total_grad_input.append(grad_input), so that list keeps accumulating more and more Tensors. You want to make sure to remove them when you don't need them anymore.
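One way to keep that list bounded is to store detached CPU copies and clear the list once each iteration's data has been used. A minimal sketch, assuming PyTorch 1.8+ and a toy model in place of the real network; the loop body where you would visualize or save the data is left as a comment.

```python
import torch
import torch.nn as nn

# Toy model (assumption); model[1] is the ReLU being hooked.
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))
relu = model[1]

total_grad_input = []  # holds data for ONE iteration only

def hook_fn_backward(module, grad_input, grad_output):
    # Store a detached CPU copy instead of the live tensor.
    total_grad_input.append(grad_input[0].detach().cpu())

handle = relu.register_full_backward_hook(hook_fn_backward)

peak = 0
for step in range(5):
    model.zero_grad()
    model(torch.randn(4, 10)).sum().backward()
    peak = max(peak, len(total_grad_input))
    # ... visualize or save total_grad_input here ...
    total_grad_input.clear()  # drop the references so the memory can be freed

handle.remove()
```

Without the clear() call the list would hold 5 entries after the loop; with it, it never holds more than one.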

Thanks for your reply. So I just need to remove the data in that list that I won't use anymore, and then this error will go away? If so, maybe I could save the data to my PC's local storage and then remove it all? Or maybe I should try using a hook on the Tensor instead. After I change it, I will post the result here. Sorry if this is a stupid question, but if you have some free time, could you check this code in which I implement the Focal Loss? Since I started using this loss, the returned loss values look like -2000, -1980, …

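import torch
import torch.nn as nn

class FocalLoss(nn.Module):

    def __init__(self, alpha=1.0, gamma=0.0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        # reduction='none' keeps one CE value per sample, so the focal weight
        # (1 - p)^gamma is applied per sample rather than to the batch mean
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, input, target):
        ce = self.ce(input, target)  # -log(p_t) for each sample, always >= 0
        p = torch.exp(-ce)           # p_t, predicted probability of the true class

        # focal loss: alpha * (1 - p_t)^gamma * (-log p_t), non-negative by construction
        loss = self.alpha * (1 - p) ** self.gamma * ce

        return loss.mean()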
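As a cross-check, here is a minimal standalone sketch of the usual multi-class focal loss, FL = -alpha * (1 - p_t)^gamma * log(p_t), written as a plain function. The name focal_loss and the defaults alpha=1.0, gamma=2.0 are illustrative choices, not from the thread. Because -log(p_t) >= 0 and 0 <= 1 - p_t <= 1, every term is non-negative, so values like -2000 point to a sign error; with gamma = 0 and alpha = 1 the function reduces to plain cross-entropy.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, target, alpha=1.0, gamma=2.0):
    """Multi-class focal loss averaged over the batch (illustrative sketch)."""
    ce = F.cross_entropy(logits, target, reduction="none")  # -log(p_t) per sample
    p_t = torch.exp(-ce)                                    # probability of the true class
    return (alpha * (1 - p_t) ** gamma * ce).mean()

torch.manual_seed(0)
logits = torch.randn(8, 5)
target = torch.randint(0, 5, (8,))

fl = focal_loss(logits, target, gamma=2.0)
ce = F.cross_entropy(logits, target)
```

Since each focal weight is at most 1, fl never exceeds the mean cross-entropy ce when alpha = 1.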