Can forward() in nn.Module be overridden with different arguments?

Can forward() be overridden with different arguments, so that a different forward is used in different scenarios?

Example:

import torch.nn as nn

class Test(nn.Module):

    def __init__(self):
        super(Test, self).__init__()

    def forward(self, x, y):
        pass

    def forward(self, x, y, z):
        pass

Hi,

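Python does not support overloading by signature: if you define two methods with the same name like that, the second definition simply replaces the first, so only forward(self, x, y, z) would exist.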
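A quick way to see this: in a class body, a second def with the same name just rebinds that name, so only the last forward survives. The toy Test class below returns strings purely for illustration.

```python
import torch
import torch.nn as nn

class Test(nn.Module):
    def forward(self, x, y):
        return "two-argument forward"

    # Rebinds the name `forward`: the definition above is discarded
    def forward(self, x, y, z):
        return "three-argument forward"

m = Test()
x = y = z = torch.zeros(1)
print(m(x, y, z))  # three-argument forward
try:
    m(x, y)  # the two-argument version no longer exists
except TypeError as err:
    print("TypeError:", err)
```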
What you can do though is:

import torch.nn as nn

class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()

    def forward(self, x, y, z=None):
        if z is None:
            # Forward with 2 arguments
            pass
        else:
            # Forward with 3 arguments
            pass
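A minimal runnable version of this pattern, assuming hypothetical Linear layers and that the inputs are simply concatenated, so that both branches actually compute something:

```python
import torch
import torch.nn as nn

class Test(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical layers, only here so both branches compute something
        self.fc_xy = nn.Linear(20, 5)
        self.fc_xyz = nn.Linear(30, 5)

    def forward(self, x, y, z=None):
        if z is None:
            # Forward with 2 arguments: concatenate x and y -> (batch, 20)
            return self.fc_xy(torch.cat([x, y], dim=1))
        # Forward with 3 arguments: concatenate x, y and z -> (batch, 30)
        return self.fc_xyz(torch.cat([x, y, z], dim=1))

model = Test()
x, y, z = torch.randn(4, 10), torch.randn(4, 10), torch.randn(4, 10)
print(model(x, y).shape)     # torch.Size([4, 5])
print(model(x, y, z).shape)  # torch.Size([4, 5])
```

The same module instance serves both call styles; which branch runs is decided by whether z was passed.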

Hello, following up on that answer:
Can I have one forward pass for training (for example, by passing a bool that says I am training) and a different forward pass for evaluation?
For example, I am training an RNN and using its last hidden state as the encoding of a sentence. During training I pass some extra information along with each sentence (for example, negative samples), but afterwards I would like to load the trained model, pass a single sentence through the RNN, and get its encoding.
Could I do that? @albanD @ptrblck

Sure!
You can adapt @albanD’s code and pass an additional flag to it, if that’s what you are looking for:

 
def forward(self, x, y, training=True):
    if training:
        # training path (e.g. also use the extra input y)
        pass
    else:
        # evaluation/inference path
        pass
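Applied to the RNN use case above, a minimal sketch could look like this. The SentenceEncoder class, its sizes, and the choice of a GRU are assumptions for illustration, not part of the original question. The training path also encodes the negative samples; the inference path only encodes the sentence.

```python
import torch
import torch.nn as nn

class SentenceEncoder(nn.Module):
    # Hypothetical encoder: embeds token ids, runs a GRU and uses the
    # last hidden state as the sentence encoding.
    def __init__(self, vocab_size=100, emb_dim=16, hidden_dim=32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.rnn = nn.GRU(emb_dim, hidden_dim, batch_first=True)

    def encode(self, tokens):
        # tokens: (batch, seq_len) -> h_n: (num_layers, batch, hidden_dim)
        _, h_n = self.rnn(self.embedding(tokens))
        return h_n[-1]  # last layer's final hidden state: (batch, hidden_dim)

    def forward(self, sentence, negatives=None, training=True):
        enc = self.encode(sentence)
        if training:
            # Training path: also encode the negative samples
            return enc, self.encode(negatives)
        # Inference path: only the sentence encoding is needed
        return enc

model = SentenceEncoder()
sent = torch.randint(0, 100, (4, 7))
negs = torch.randint(0, 100, (4, 7))
enc, neg = model(sent, negs)  # training-style call

model.eval()
with torch.no_grad():
    enc_only = model(sent, training=False)  # just the encoding
```

Note that this explicit flag is independent of model.eval(); you would still call eval() so that layers such as dropout switch to inference behavior.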

Also, if your forward method's behavior depends on the module's training mode (model.train() vs. model.eval()), you don't need an extra flag; you can simply read the built-in self.training attribute:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 2)
        
    def forward(self, x):
        if self.training:
            return x
        else:
            return self.fc(x)

x = torch.randn(1, 10)
model = MyModel()
print(model(x).shape)
> torch.Size([1, 10])
model.eval()
print(model(x).shape)
> torch.Size([1, 2])
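Because model.train() and model.eval() set the training attribute recursively on every submodule, a check on self.training inside a nested module follows the parent's mode as well. The Head and Wrapper classes below are hypothetical, only to show this:

```python
import torch
import torch.nn as nn

class Head(nn.Module):
    # Hypothetical submodule whose behavior depends on its own training flag
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return x if self.training else self.fc(x)

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.head = Head()

    def forward(self, x):
        return self.head(x)

model = Wrapper()
x = torch.randn(1, 10)
print(model.head.training, model(x).shape)  # True torch.Size([1, 10])
model.eval()  # also sets model.head.training = False
print(model.head.training, model(x).shape)  # False torch.Size([1, 2])
```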

That is exactly what I was looking for. Thank you very much!


Hi,
Would placing if/else statements in the forward pass slow it down appreciably during training?
Thanks
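One way to check this for a specific model is to time the forward pass with and without the branch. The sketch below uses two hypothetical toy modules that do the same computation, one behind a Python if/else on self.training, and times them with the standard timeit module on CPU in eager mode. It measures rather than assumes the answer, and results on your own model and hardware may differ.

```python
import timeit
import torch
import torch.nn as nn

class Branchless(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

class Branching(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        # Same computation, but behind a Python if/else on self.training
        if self.training:
            return self.fc(x)
        else:
            return self.fc(x)

def time_forward(model, x, number=1000):
    # Average wall-clock seconds per forward call on CPU. On a GPU you would
    # also need torch.cuda.synchronize(), since CUDA kernels run asynchronously.
    with torch.no_grad():
        total = timeit.timeit(lambda: model(x), number=number)
    return total / number

x = torch.randn(64, 10)
print(f"without branch: {time_forward(Branchless(), x):.2e} s")
print(f"with branch:    {time_forward(Branching(), x):.2e} s")
```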