Can hook remove itself?

After it runs, can a hook remove itself from inside?

When you register a hook, the register method returns a handle to this hook, which can be used to remove it.

Yes, the handle works, but what I actually wanted to check is whether it is possible to do this from inside the hook function itself, in some easy and smart way.

Usually I need the hook to run only once.

You could try to access the handle from inside the hook:

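import torch
import torch.nn as nn

def hook(grad):
    print(grad)
    handle.remove()  # remove this hook after its first call

model = nn.Linear(1, 1)
handle = model.weight.register_hook(hook)

model(torch.randn(1, 1)).backward()  # hook fires and removes itself
model(torch.randn(1, 1)).backward()  # hook is no longer called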

This code will print the gradient only once and remove the hook.
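If you would rather not rely on a global `handle` variable, one option is to wrap the hook in a closure that captures its own handle. This is a minimal sketch; `register_hook_once` is a hypothetical helper name, not a PyTorch API:

```python
import torch
import torch.nn as nn


def register_hook_once(tensor, fn):
    """Hypothetical helper (not part of PyTorch): register `fn` as a
    gradient hook on `tensor` that removes itself after its first call."""
    handle = None

    def wrapper(grad):
        handle.remove()   # detach first so the hook never fires again
        return fn(grad)   # returning None leaves the gradient unchanged

    handle = tensor.register_hook(wrapper)
    return handle


calls = []
model = nn.Linear(1, 1)
register_hook_once(model.weight, lambda grad: calls.append(grad.clone()))

model(torch.randn(1, 1)).backward()
model(torch.randn(1, 1)).backward()
print(len(calls))  # 1
```

The wrapper reads `handle` from the enclosing scope at call time, so by the time the hook runs, the variable already holds the handle returned by `register_hook`.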

How can I remove the hooks if they are registered like this?

def register_hooks(self):
    """Register forward and backward hooks on the Conv modules."""
    for module, name in self.conv_names.items():
        module.register_forward_hook(self.save_input_forward_hook)
        module.register_full_backward_hook(self.compute_fisher_backward_hook)

self.register_hooks()

You could store the handle, which is returned by the register..._hook methods, and remove it when needed.
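A minimal sketch of that pattern, assuming a small `nn.Sequential` model and a hypothetical `save_input_forward_hook` that only records input shapes: keep every handle returned by `register_forward_hook` in a list and call `remove()` on each one when you are done.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
seen = []


def save_input_forward_hook(module, inputs, output):
    # Hypothetical hook: record the shape of the first positional input.
    seen.append(tuple(inputs[0].shape))


# Keep every handle returned by register_forward_hook.
handles = [
    m.register_forward_hook(save_input_forward_hook)
    for m in model.modules()
    if isinstance(m, nn.Conv2d)
]

x = torch.randn(1, 3, 10, 10)
model(x)        # both conv hooks fire
for h in handles:
    h.remove()  # detach all hooks at once
model(x)        # no hook fires any more
print(seen)     # [(1, 3, 10, 10), (1, 8, 8, 8)]
```

The same approach works for the handles returned by `register_full_backward_hook`.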

I did the following:

def register_hooks(self):
    """Register forward and backward hooks on the Conv modules."""
    self.forward_hooks, self.backward_hooks = [], []
    for module, name in self.conv_names.items():
        self.forward_hooks.append(module.register_forward_hook(self.save_input_forward_hook))
        self.backward_hooks.append(module.register_full_backward_hook(self.compute_fisher_backward_hook))

def remove_hooks(self):
    for fh in self.forward_hooks:
        fh.remove()
    for bh in self.backward_hooks:
        bh.remove()

Then I remove the hooks and try to save the model:

from copy import deepcopy

torch.save(deepcopy(model).half(), path)

and get the following error:

AttributeError: 'Conv2d' object has no attribute 'modified_forward'

I would be grateful for any help!

Could you post an executable code snippet, which would reproduce the issue, please?

Sorry, I was wrong: I can save the model, but I get this error when I try to load it.
Here is the code snippet where the model is loaded:

    if RANK in [-1, 0]:
        LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
        for f in last, best:
            if f.exists():
                strip_optimizer(f)  # strip optimizers
                if f is best:
                    LOGGER.info(f'\nValidating {f}...')
                    results, _, _ = val.run(data_dict,
                                            batch_size=batch_size // WORLD_SIZE * 2,
                                            imgsz=imgsz,
                                            model=attempt_load(f, device).half(),
                                            iou_thres=0.65 if is_coco else 0.60,  # best pycocotools results at 0.65
                                            single_cls=single_cls,
                                            dataloader=val_loader,
                                            save_dir=save_dir,
                                            save_json=is_coco,
                                            verbose=True,
                                            plots=True,
                                            callbacks=callbacks,
                                            compute_loss=compute_loss)

def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    # Strip optimizer from 'f' to finalize training, optionally save as 's'
    x = torch.load(f, map_location=torch.device('cpu'))
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, s or f)
    mb = os.path.getsize(s or f) / 1E6  # filesize
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")

Traceback (most recent call last):
  File "train.py", line 651, in <module>
    main(opt)
  File "train.py", line 548, in main
    train(opt.hyp, opt, device, callbacks)
  File "train.py", line 437, in train
    strip_optimizer(f)  # strip optimizers
  File "/home/apleshkova/yolov5/utils/general.py", line 703, in strip_optimizer
    x = torch.load(f, map_location=torch.device('cpu'))
  File "/home/apleshkova/.local/lib/python3.6/site-packages/torch/serialization.py", line 607, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/home/apleshkova/.local/lib/python3.6/site-packages/torch/serialization.py", line 882, in _load
    result = unpickler.load()
  File "/home/apleshkova/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1131, in __getattr__
    type(self).__name__, name))
AttributeError: 'Conv2d' object has no attribute 'modified_forward'

Unfortunately, the code is not executable, so I can’t debug it, but I can give my best guess.
It seems you are storing the entire model instead of its state_dict, which I would not recommend, as I’ve seen it fail in various ways.
Store the state_dict instead, recreate the model instance, and load the state_dict back afterwards.
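A minimal sketch of that workflow, assuming a hypothetical `build_model()` factory that rebuilds the architecture in code; an in-memory buffer stands in for a checkpoint file. Because a state_dict holds only parameter and buffer tensors, loading it does not need to reconstruct any extra attributes or hooks that were attached to the module objects.

```python
import io

import torch
import torch.nn as nn


def build_model():
    # Hypothetical factory: loading a state_dict requires rebuilding
    # exactly the same architecture in code first.
    return nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))


model = build_model()

# Save only the parameter/buffer tensors, not the pickled module objects.
# The in-memory buffer stands in for a checkpoint path such as 'best.pt'.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Recreate the model, then load the saved tensors into it.
buffer.seek(0)
restored = build_model()
restored.load_state_dict(torch.load(buffer, map_location="cpu"))
print(list(restored.state_dict().keys()))  # ['0.weight', '0.bias', '2.weight', '2.bias']
```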

Thank you for your advice! This solved my problem.