Access epoch number from within the model's forward pass

I apologize if this question has already been covered; I couldn't find many resources on this topic. I am trying to execute a block of code in my model's forward pass, but only after a certain number of epochs has been reached. From my understanding there is no straightforward way of achieving this (e.g. this GitHub issue describes exactly what I'm trying to do, but unfortunately led nowhere).

Any suggestions on how to achieve this?

Thank you
Darius

I guess it is doable.

Consider an example:

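import torch

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.current_epoch = 0  # updated from the training loop

    def update_epoch(self, epoch):
        self.current_epoch = epoch

    def forward(self, x):
        # some_condition is a placeholder for your own check,
        # e.g. self.current_epoch >= 10
        if some_condition(self.current_epoch):
            ...  # code that should only run once the condition holds
        else:
            ...  # default behaviour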
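A variant of the class above, in case the epoch counter should also survive saving and reloading a checkpoint: a plain Python int attribute is not part of `state_dict()`, whereas a registered buffer is. This is a minimal sketch; the `switch_epoch` threshold, the `Linear(4, 2)` layer and the extra ReLU that only runs once the threshold is reached are all hypothetical stand-ins for your own condition and code.

```python
import torch


class MyModel(torch.nn.Module):
    def __init__(self, switch_epoch: int = 5):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)
        # Hypothetical threshold: the extra branch turns on from this epoch onward.
        self.switch_epoch = switch_epoch
        # A buffer is included in state_dict() and moved by .to(device),
        # unlike a plain Python int attribute.
        self.register_buffer("current_epoch", torch.tensor(0, dtype=torch.long))

    def update_epoch(self, epoch: int) -> None:
        # fill_ updates the buffer in place, so it stays on its device.
        self.current_epoch.fill_(epoch)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.fc(x)
        # .item() converts the 0-dim tensor to a Python int for the comparison.
        if self.current_epoch.item() >= self.switch_epoch:
            out = torch.relu(out)  # hypothetical "late" behaviour
        return out


model = MyModel(switch_epoch=3)
model.update_epoch(3)
restored = MyModel(switch_epoch=3)
restored.load_state_dict(model.state_dict())
print(restored.current_epoch.item())  # 3
```

If the buffer lives on a GPU, `.item()` forces a device sync; for a once-per-batch check this is usually negligible, but keeping a separate Python int alongside the buffer avoids it.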

Before every epoch of your training, update the epoch stored inside the model, and that's it 🙂

for epoch in range(1, epochs + 1):
    model.update_epoch(epoch)  # set the epoch before iterating over batches
    for data, labels in dataloader:
        ...  # usual forward / backward / optimizer step
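Put together, a self-contained run of this pattern might look like the sketch below. The random `TensorDataset`, the `Linear(4, 1)` layer, the `warmup_epochs` value, the `tanh` branch and the `branch_log` list (used only to record which branch each forward call took) are all assumptions for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class MyModel(torch.nn.Module):
    def __init__(self, warmup_epochs: int = 2):
        super().__init__()
        self.fc = torch.nn.Linear(4, 1)
        self.warmup_epochs = warmup_epochs  # hypothetical threshold
        self.current_epoch = 0
        self.branch_log = []  # records (epoch, branch) for each forward call

    def update_epoch(self, epoch: int) -> None:
        self.current_epoch = epoch

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.fc(x)
        if self.current_epoch > self.warmup_epochs:
            self.branch_log.append((self.current_epoch, "late"))
            out = torch.tanh(out)  # hypothetical post-warm-up behaviour
        else:
            self.branch_log.append((self.current_epoch, "early"))
        return out


torch.manual_seed(0)
dataset = TensorDataset(torch.randn(16, 4), torch.randn(16, 1))
dataloader = DataLoader(dataset, batch_size=8)  # 2 batches per epoch
model = MyModel(warmup_epochs=2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
epochs = 4

for epoch in range(1, epochs + 1):
    model.update_epoch(epoch)  # tell the model which epoch it is in
    for data, labels in dataloader:
        optimizer.zero_grad()
        loss = loss_fn(model(data), labels)
        loss.backward()
        optimizer.step()

print(sorted(set(model.branch_log)))
# [(1, 'early'), (2, 'early'), (3, 'late'), (4, 'late')]
```

Epochs 1 and 2 take the default branch and epochs 3 and 4 take the gated one, without the forward signature having to change.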

Thanks, this did the trick! Fairly clear and elegant solution!