Can a hook remove itself from inside, after it runs?
If you register a hook, you should get a handle to this hook, which can be used to remove it.
Yes, the handle works, but what I actually wanted to know is whether it is possible to do this from inside the hook function itself, in some easy and smart way.
Usually, what I need is for the hook to run only once.
You could access the handle from inside the hook:
def hook(grad):
    """Print ``grad`` once, then unregister this hook.

    ``handle`` is looked up as a module-level global when the hook is called,
    so it only has to be bound before the first backward pass, not before this
    function is defined.
    """
    print(grad)
    handle.remove()  # later backward passes no longer call this hook


model = nn.Linear(1, 1)
handle = model.weight.register_hook(hook)
model(torch.randn(1, 1)).backward()  # hook runs and prints the weight gradient
model(torch.randn(1, 1)).backward()  # hook was removed by the first call: prints nothing
This code will print the gradient only once and remove the hook.
How can I remove the hooks if they are registered like this?
def register_hooks(self):
    """Register the forward and full backward hooks on every Conv module.

    Each module in ``self.conv_names`` gets ``self.save_input_forward_hook`` as a
    forward hook and ``self.compute_fisher_backward_hook`` as a full backward hook.
    The handles returned by the ``register_*_hook`` calls are kept in
    ``self.forward_hooks`` and ``self.backward_hooks``; without them the hooks
    could never be removed. Call ``handle.remove()`` on each to unregister.
    Calling this again replaces both lists, so remove the previous hooks first.
    """
    self.forward_hooks, self.backward_hooks = [], []
    for module, _name in self.conv_names.items():  # only the module keys are needed
        self.forward_hooks.append(module.register_forward_hook(self.save_input_forward_hook))
        self.backward_hooks.append(
            module.register_full_backward_hook(self.compute_fisher_backward_hook)
        )
self.register_hooks()  # NOTE(review): surrounding context not shown; presumably called from the owning class's __init__ — confirm
You could store the handle, which is returned by the register..._hook methods, and remove it when needed.
I did the following:
def register_hooks(self):
    """Register the forward and full backward hooks on every Conv module.

    Each module in ``self.conv_names`` gets ``self.save_input_forward_hook`` as a
    forward hook and ``self.compute_fisher_backward_hook`` as a full backward hook.
    The returned handles are kept in ``self.forward_hooks`` and
    ``self.backward_hooks`` so the hooks can be removed later via ``handle.remove()``.
    Calling this again replaces both lists, so remove the previous hooks first.
    """
    self.forward_hooks, self.backward_hooks = [], []
    for module, _name in self.conv_names.items():  # only the module keys are needed
        self.forward_hooks.append(module.register_forward_hook(self.save_input_forward_hook))
        self.backward_hooks.append(
            module.register_full_backward_hook(self.compute_fisher_backward_hook)
        )
def remove_hooks(self):
    """Remove every hook whose handle is stored in ``forward_hooks`` / ``backward_hooks``.

    The handle lists are emptied afterwards, so stale handles are not kept around.
    Calling this more than once, or before any handles were stored, does nothing
    instead of raising ``AttributeError``.
    """
    for handles in (getattr(self, "forward_hooks", []), getattr(self, "backward_hooks", [])):
        for handle in handles:
            handle.remove()
        handles.clear()
Then I remove the hooks and try to save the model:
torch.save(
    obj=deepcopy(model).half(),  # copy first so the in-place .half() leaves `model` itself untouched
    f=path,
)
and I get the following error:
AttributeError: 'Conv2d' object has no attribute 'modified_forward'
I would be grateful for any help!
Could you post an executable code snippet that reproduces the issue, please?
Sorry, I was wrong: I can save the model, but when I try to load it, I get this error.
Here is the code snippet where the model is loaded:
# NOTE(review): excerpt from a larger training function; RANK, LOGGER, epoch, start_epoch,
# t0, last, best, val, data_dict, ... are defined in code not shown here.
if RANK in [-1, 0]:  # presumably: single-process run (-1) or the main distributed rank (0) — confirm
    LOGGER.info(f'\n{epoch - start_epoch + 1} epochs completed in {(time.time() - t0) / 3600:.3f} hours.')
    for f in last, best:
        if f.exists():
            strip_optimizer(f)  # strip optimizers
            # Identity check: only the `best` checkpoint object itself is re-validated.
            if f is best:
                LOGGER.info(f'\nValidating {f}...')
                results, _, _ = val.run(data_dict,
                                        batch_size=batch_size // WORLD_SIZE * 2,
                                        imgsz=imgsz,
                                        model=attempt_load(f, device).half(),
                                        iou_thres=0.65 if is_coco else 0.60,  # best pycocotools results at 0.65
                                        single_cls=single_cls,
                                        dataloader=val_loader,
                                        save_dir=save_dir,
                                        save_json=is_coco,
                                        verbose=True,
                                        plots=True,
                                        callbacks=callbacks,
                                        compute_loss=compute_loss)
def strip_optimizer(f='best.pt', s=''):  # from utils.general import *; strip_optimizer()
    """Strip training-only state from checkpoint ``f`` to finalize training.

    Loads the checkpoint dict and replaces ``'model'`` with ``'ema'`` when an EMA
    entry is present. It then sets the optimizer, EMA and bookkeeping entries to
    ``None``, sets ``'epoch'`` to -1, converts the model to FP16 with gradients
    disabled, and saves the result.

    Args:
        f: path of the checkpoint to read.
        s: optional output path; when empty, ``f`` is overwritten.
    """
    import inspect  # local import: only used to probe torch.load's signature

    out = s or f
    load_kwargs = {'map_location': torch.device('cpu')}
    # torch>=2.6 defaults to torch.load(weights_only=True), which refuses to unpickle the full
    # nn.Module stored under 'model'. Opt out explicitly where the argument exists (torch>=1.13);
    # older versions have no such argument and always unpickle fully.
    # SECURITY: full unpickling can execute arbitrary code; only strip trusted checkpoints.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    x = torch.load(f, **load_kwargs)
    if x.get('ema'):
        x['model'] = x['ema']  # replace model with ema
    for k in 'optimizer', 'training_results', 'wandb_id', 'ema', 'updates':  # keys
        x[k] = None
    x['epoch'] = -1
    x['model'].half()  # to FP16 (in place on the module)
    for p in x['model'].parameters():
        p.requires_grad = False
    torch.save(x, out)
    mb = os.path.getsize(out) / 1E6  # filesize
    print(f"Optimizer stripped from {f},{(' saved as %s,' % s) if s else ''} {mb:.1f}MB")
Traceback (most recent call last):
File "train.py", line 651, in <module>
main(opt)
File "train.py", line 548, in main
train(opt.hyp, opt, device, callbacks)
File "train.py", line 437, in train
strip_optimizer(f) # strip optimizers
File "/home/apleshkova/yolov5/utils/general.py", line 703, in strip_optimizer
x = torch.load(f, map_location=torch.device('cpu'))
File "/home/apleshkova/.local/lib/python3.6/site-packages/torch/serialization.py", line 607, in load
return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
File "/home/apleshkova/.local/lib/python3.6/site-packages/torch/serialization.py", line 882, in _load
result = unpickler.load()
File "/home/apleshkova/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 1131, in __getattr__
type(self).__name__, name))
AttributeError: 'Conv2d' object has no attribute 'modified_forward'
Unfortunately, the code is not executable, so I can’t debug it, but I can give my best guess.
It seems you are storing the entire model instead of its state_dict, which I would not recommend, as I’ve seen it fail in various ways.
Store the state_dict instead, recreate the model instance, and load the state_dict back into it afterwards.
Thank you for your advice! This solved my problem :)