Changing model behavior based on epoch during training


I want to change the forward behavior based on the epoch, but I can't figure out how to pass either the epoch or a simple True/False flag to the network during training.

Should I make it a parameter? If so, can I change it in the training loop?

I did find a post that shows how to pass a value to forward, but since it's an nn.Module, the layers are defined in __init__, so I don't see how to get that value from the forward pass into the layers.


I'm not sure if this would fit your use case, but you could directly manipulate a model attribute:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 1)
        self.fc2 = nn.Linear(1, 1)
        self.switch = 0  # plain attribute; set from the training loop

    def forward(self, x):
        if self.switch == 0:
            print('using fc1')
            x = self.fc1(x)
        elif self.switch == 1:
            print('using fc2')
            x = self.fc2(x)
            print('returning x')
        # note: any other switch value returns the input unchanged
        return x

model = MyModel()

for epoch in range(3):
    x = torch.randn(1, 1)
    model.switch = epoch  # toggle the behavior from the training loop
    out = model(x)

> using fc1
> using fc2
> returning x
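Alternatively, you could pass the epoch (or a boolean flag) directly as an extra argument to forward, which avoids mutating the model from the training loop; extra positional arguments are simply forwarded through the module call. A minimal sketch (the class and branch logic here are illustrative):

```python
import torch
import torch.nn as nn

class EpochAwareModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 1)
        self.fc2 = nn.Linear(1, 1)

    def forward(self, x, epoch):
        # The caller decides which branch runs by passing the epoch in
        if epoch == 0:
            return self.fc1(x)
        return self.fc2(x)

model = EpochAwareModel()
for epoch in range(3):
    x = torch.randn(1, 1)
    out = model(x, epoch)  # extra args are passed through to forward
```

This keeps the epoch dependence explicit at the call site instead of as hidden state on the module.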

That works for now, thanks! It's not ideal because I have to share weights that way, but it keeps me moving.