Hook function that needs the current epoch

Hi. I have a problem with a PyTorch hook function.
As mentioned in the title, my hook needs the current epoch in order to record some statistics that change every epoch.
Is there a trick for solving this problem?

For example,

my_model = some_net()
my_model.specific_layer.register_forward_hook(tricky_hook_for_recording_epoch)

for epoch in range(epoch_size):
    for i, (data, label) in enumerate(dataloader):
        output = my_model(data)

I’m using TensorBoard with PyTorch, which is why I need to record the epoch.
Thanks for any answer!

Hi @FruitVinegar

pytorch/ignite might help you achieve that easily.
Please look here for a description of how it works.

But if you don’t want to bother with Ignite, I guess you could also do this:

prev_epoch = -1  # start below the first epoch so epoch 0 is also logged

def tricky_hook_for_recording_epoch(module, input, output):
    global prev_epoch  # the hook reassigns this module-level name
    if epoch > prev_epoch:
        # TensorBoard stats, logging, etc. (runs once per epoch,
        # on the first forward pass of that epoch)
        prev_epoch = epoch

my_model = some_net()
my_model.specific_layer.register_forward_hook(tricky_hook_for_recording_epoch)

for epoch in range(epoch_size):
    for i, (data, label) in enumerate(dataloader):
        output = my_model(data)
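If you’d rather avoid globals, another option is to make the hook a small callable object that carries the current epoch as state, and update that state at the top of each epoch. A minimal sketch, assuming a toy model in place of `some_net()`; the `EpochHook` class and its `logged_epochs` list are hypothetical names standing in for your TensorBoard writer calls:

```python
import torch
import torch.nn as nn

class EpochHook:
    """Forward hook that knows the current epoch (hypothetical helper)."""

    def __init__(self):
        self.epoch = 0
        self.logged_epochs = []  # stand-in for real TensorBoard logging

    def __call__(self, module, input, output):
        # Record each epoch's statistics only once,
        # on the first forward pass of that epoch.
        if self.epoch not in self.logged_epochs:
            self.logged_epochs.append(self.epoch)  # e.g. writer.add_histogram(...)

# Toy model standing in for some_net()
model = nn.Sequential(nn.Linear(4, 2))
hook = EpochHook()  # register_forward_hook accepts any callable
model[0].register_forward_hook(hook)

for epoch in range(3):
    hook.epoch = epoch        # tell the hook which epoch we are in
    for _ in range(5):        # pretend dataloader with 5 batches
        model(torch.randn(8, 4))

print(hook.logged_epochs)  # each epoch recorded exactly once: [0, 1, 2]
```

Since `register_forward_hook` only requires a callable, the same idea also works with `functools.partial` or a closure if you prefer functions over classes.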

I was refactoring my spaghetti code with many parameters and libraries, so Ignite looks helpful to me. Thank you for introducing it!