For illustration, I have condensed my code to the following snippet:
import torch.nn as nn
import torch

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.func = nn.Linear(10, 5)

    def forward(self, x, y):
        print(y)
        return self.func(x)

model = nn.Sequential(*[Model()])
x = y = torch.randn(size=(20, 10))
model(x, y)
I get the following error:
TypeError: forward() takes 2 positional arguments but 3 were given
This is because nn.Sequential has its own forward, which accepts only one argument. How can I change or extend that forward so that it accepts two arguments? Obviously I could pass x and y together as a list in a single argument, but in my use case that would make the code messy.
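One possible approach, sketched here as an illustration rather than an official PyTorch feature, is to subclass nn.Sequential and override forward to take *inputs. The class name MultiInputSequential is hypothetical. The sketch assumes each child module returns either a single tensor (passed on as one argument) or a tuple that should be unpacked into the next module's positional arguments.

```python
import torch
import torch.nn as nn


class MultiInputSequential(nn.Sequential):
    """Hypothetical nn.Sequential subclass whose forward accepts several inputs.

    Assumption: a tuple returned by one module is unpacked into positional
    arguments for the next module; any other output is passed on as-is.
    """

    def forward(self, *inputs):
        for module in self:  # child modules in registration order
            if isinstance(inputs, tuple):
                inputs = module(*inputs)  # unpack multiple values
            else:
                inputs = module(inputs)  # single value, pass through
        return inputs


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.func = nn.Linear(10, 5)

    def forward(self, x, y):
        print(y.shape)  # y is received as a separate argument
        return self.func(x)


model = MultiInputSequential(Model())
x = y = torch.randn(20, 10)
out = model(x, y)
print(out.shape)  # torch.Size([20, 5])
```

The unpacking rule is a design choice with a trade-off. A module that really expects a tuple as one argument would instead receive its elements spread across several arguments, so in that case this heuristic would need to be adjusted.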