Question on passing parameters to forward() method

I have already built my network, and now I would like to pass input to it. My forward() method looks like this:

def forward(self, input, index=None):
    if index is not None:
        # Do something that uses index
        pass
    else:
        # Do something else
        pass

And the way I pass input to network is:

output = net(input, 1)

This leads to a problem: index is still ‘None’ even though I pass the value ‘1’ to the network.
After searching online, it seems that the forward() method can only accept inputs of type Variable.
It seems to work after I modified the code as follows:

from torch.autograd import Variable
output = net(input, Variable(torch.IntTensor([1])))

However, is there any simpler way of doing this? Thank you for your time.


Hi @henrych4.
I have the same problem, but I think it’s due to the fact that internally PyTorch calls __call__, which in turn calls forward with only the input parameter, so index takes its default value, None.
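On recent PyTorch releases, nn.Module.__call__ passes extra positional and keyword arguments straight through to forward(), and Variable has been merged into Tensor, so plain Python values should arrive unchanged. A minimal check along these lines can confirm how the installed version behaves; the Net class and the `input + index` arithmetic are hypothetical placeholders standing in for "do something".

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def forward(self, input, index=None):
        # Branch on whether an extra argument arrived through __call__
        if index is not None:
            return input + index  # hypothetical placeholder for "do something"
        return input  # hypothetical placeholder for "do something else"


net = Net()
x = torch.zeros(2)
print(net(x))           # index falls back to its default, None
print(net(x, 1))        # extra positional argument
print(net(x, index=2))  # extra keyword argument
```

If the second and third calls print tensors shifted by 1 and 2, the extra argument is reaching forward() and no wrapping in Variable is needed.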

The simplest solution I found was to add a set_index(self, index) function to my network and propagate it through the modules.
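The set_index workaround can be sketched roughly as follows. This is only an assumption about how it might look: the Block submodule, the stored `index` attribute, and the multiplication are hypothetical, and only the name set_index comes from the post above.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.index = None  # hypothetical state set from outside instead of via forward()

    def set_index(self, index):
        self.index = index

    def forward(self, x):
        if self.index is not None:
            return x * self.index  # hypothetical placeholder for "do something"
        return x  # hypothetical placeholder for "do something else"


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.block1 = Block()
        self.block2 = Block()

    def set_index(self, index):
        # Propagate the index to every direct submodule that accepts it
        for module in self.children():
            if hasattr(module, "set_index"):
                module.set_index(index)

    def forward(self, x):
        return self.block2(self.block1(x))


net = Net()
net.set_index(3)
out = net(torch.ones(2))
```

Because the index is stored as state, it stays in effect for every later call until it is reset, e.g. with net.set_index(None).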

If you found a better solution, please share :sunglasses:
