Hi,
I want to change the behavior of `forward` based on the epoch, but I can't figure out how to pass either the epoch or a simple True/False flag to the network during training.
Should I make it a parameter? If so, can I change it in the training loop?
I did find a post that shows how to pass the value to `forward`, but my network is a module whose layers I'm still defining, so I don't see how to get that value from the forward pass down to the layers.
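A minimal sketch of the forward-argument approach, assuming the flag only needs to reach a custom child module: the top-level `forward` takes an extra keyword argument and passes it on explicitly to the layer that needs it. `SwitchableBlock`, `Net`, `use_alt` and the `epoch >= 2` threshold are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn


class SwitchableBlock(nn.Module):
    """Hypothetical layer whose behavior depends on a flag passed at call time."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1)

    def forward(self, x, use_alt=False):
        out = self.fc(x)
        # Alternative path reuses the same weights but skips the activation.
        return out if use_alt else torch.relu(out)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = SwitchableBlock()

    def forward(self, x, epoch=0):
        # Derive the flag from the epoch and hand it down explicitly.
        return self.block(x, use_alt=epoch >= 2)


model = Net()
for epoch in range(3):
    x = torch.randn(4, 1)
    # Extra keyword arguments given to model(...) are passed through to forward().
    out = model(x, epoch=epoch)
```

Because both branches go through the same `self.fc`, the weights are shared between them; only the code path changes.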
Thanks
I'm not sure if this would fit your use case, but you could directly manipulate a model attribute:
```python
import torch
import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc1 = nn.Linear(1, 1)
        self.fc2 = nn.Linear(1, 1)
        self.switch = 0  # plain attribute; can be changed from outside at any time

    def forward(self, x):
        # Pick a branch based on the current value of self.switch
        if self.switch == 0:
            print('using fc1')
            x = self.fc1(x)
        elif self.switch == 1:
            print('using fc2')
            x = self.fc2(x)
        else:
            print('returning x')
        return x


model = MyModel()
for epoch in range(3):
    x = torch.randn(1, 1)
    model.switch = epoch  # update the flag from the training loop
    out = model(x)
```
```
using fc1
using fc2
returning x
```
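The same attribute approach also works when the flag has to reach layers nested inside the model. A minimal sketch, assuming every layer that cares about the flag exposes a `switch` attribute; `GatedLinear` and `set_switch` are hypothetical names. `nn.Module.apply` visits every submodule, so a single call updates all of them.

```python
import torch
import torch.nn as nn


class GatedLinear(nn.Module):
    """Hypothetical layer that reads a `switch` attribute instead of a forward argument."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1)
        self.switch = False  # plain Python attribute, not an nn.Parameter

    def forward(self, x):
        out = self.fc(x)
        return torch.relu(out) if self.switch else out


def set_switch(model, value):
    """Hypothetical helper: set `switch` on every submodule that defines it."""
    def _set(module):
        if hasattr(module, "switch"):
            module.switch = value
    # Module.apply calls _set on each submodule recursively, then on the model itself.
    model.apply(_set)


model = nn.Sequential(GatedLinear(), GatedLinear())
for epoch in range(3):
    set_switch(model, epoch >= 1)
    out = model(torch.randn(4, 1))
```

This mirrors how `model.train()` and `model.eval()` work: they recursively set the built-in `training` flag, which layers such as `nn.Dropout` check inside their `forward`.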
That works for now, thanks! It's not ideal because I have to share weights that way, but it keeps me moving.