Forward hook is not always called

Hi, I am using forward hooks while applying a model to my data. During the forward pass on the last batch (which contains fewer samples than the batch size), the forward hook is not always called, and I do not understand why.

The hook code looks like this:

mapping_hooks = []
model_wrapper.range_dict = {}
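def get_input_range(mem, name):
  # Returns a forward hook that stores a range tensor in `mem` under a key built from the module name.
  def get_input_range_hook(module, input, output):
    print('HOOK')
    ...
    # cal_range is computed in the elided code above; the key includes the bar value and the device.
    mem[name + f'{module.bar} bar' + str(output.device)] = cal_range
  return get_input_range_hook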

for name_, module_ in model_wrapper.named_modules():
  mapping_hooks.append(module_.register_forward_hook(get_input_range(model_wrapper.range_dict, name_)))

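with torch.no_grad():
  for batch in data:
    for bar in bar_list:
      # The hook reads module.bar, so set 'bar' (not 'foo') on every submodule.
      model_wrapper.apply(lambda m: setattr(m, 'bar', bar))
      model_wrapper(batch)
      # Shallow copy of everything recorded so far; range_dict is never cleared between iterations.
      range_dict = {**model_wrapper.range_dict}
      for k, v in range_dict.items():
        print('bar: ', bar)
        print('v: ', v.shape)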

for hook in mapping_hooks:
  hook.remove()
mapping_hooks = []
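
To narrow down a case like this, it can help to count hook invocations directly instead of relying on printed output. The following is only a minimal sketch under stated assumptions, not the original setup: the tiny `nn.Sequential` model, `call_counts`, `last_shapes`, and `make_counting_hook` are hypothetical names, everything runs on CPU without nn.DataParallel, and the batch sizes 4 and 3 stand in for a full batch and a smaller last batch.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real model; CPU only, no nn.DataParallel.
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU())

call_counts = {}  # number of hook calls per module name
last_shapes = {}  # output shape seen by the most recent hook call


def make_counting_hook(name):
    def hook(module, inputs, output):
        call_counts[name] = call_counts.get(name, 0) + 1
        last_shapes[name] = tuple(output.shape)
    return hook


# named_modules() also yields the root module under the empty name ''.
handles = [
    module.register_forward_hook(make_counting_hook(name))
    for name, module in model.named_modules()
]

# One full batch of 4 samples followed by a smaller last batch of 3.
batches = [torch.randn(4, 8), torch.randn(3, 8)]
bar_list = [0, 1]

with torch.no_grad():
    for batch in batches:
        for bar in bar_list:
            # Clear the records first so that entries left over from an
            # earlier forward pass cannot be mistaken for fresh ones.
            call_counts.clear()
            last_shapes.clear()
            model(batch)
            print(bar, dict(call_counts), dict(last_shapes))

for handle in handles:
    handle.remove()
```

If a hook really were skipped for some forward pass, the corresponding name would be missing from `call_counts` after that pass, which is easier to spot than a missing print line.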

In the final log, 'HOOK' is printed for every batch except the last batch with the second bar value (for the last batch, it is printed correctly for the first bar value). The shape of v printed for the last batch with the second bar value is also wrong: it should be [80, 3, 32, 32] but is [128, 3, 32, 32]. Here 128 is the batch size and 80 is the number of samples in the last batch (CIFAR-10 with 50k images and a batch size of 128). I use only one GPU, and model_wrapper is an nn.DataParallel instance.
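
As a quick check on the expected shape, the size of the last batch is the remainder of the dataset size divided by the batch size. This is a small sketch that assumes the loader keeps the incomplete final batch; the variable names are illustrative only.

```python
num_images = 50_000  # number of CIFAR-10 images, as stated in the post
batch_size = 128

# divmod gives the number of full batches and the size of the leftover batch.
full_batches, last_batch = divmod(num_images, batch_size)
print(full_batches, last_batch)  # 390 full batches, then one batch of 80
```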

Thanks a lot.

First update: it seems that the hook is called only once per batch, for the first bar value only.

Second update: I updated the code slightly to be closer to the real code (the range_dict keys now include the 'bar' value and the device). Also, the original description appears to be correct after all: the hook is called for every bar value and every batch, except for the last batch, where it is called only for the first bar value.

Sorry, I found the problem.