I am modifying ResNet to train an additional variable, emb, which depends on some input metadata. I'm having trouble passing that metadata down to my user-defined layer inside ResNet.
In the diagram, emb is the “Category Embedding” and meta is “Dress”.
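For context, here is a minimal sketch of what a category embedding like emb could look like. It assumes meta arrives as one integer category ID per sample (for example, some index standing for “Dress”). The sizes num_categories and emb_dim are hypothetical, not taken from the original model:

```python
import torch
import torch.nn as nn

# Hypothetical sizes: 10 categories, 64-dimensional embedding vectors.
num_categories = 10
emb_dim = 64

# A trainable lookup table with one row per category.
emb = nn.Embedding(num_categories, emb_dim)

# Hypothetical batch of 4 category IDs (e.g. 3 might stand for "Dress").
meta = torch.tensor([3, 0, 3, 7])

# One trainable vector per sample; gradients flow back into emb.weight.
vectors = emb(meta)
print(vectors.shape)  # torch.Size([4, 64])
```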
Main function: I changed

    outputs = model(inputs)

to

    outputs = model(inputs, meta)

where model is ResNet-18.
ResNet forward function:

    def forward(self, x, meta):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.layer1(x, meta)
BasicBlock forward function:

    def forward(self, x, meta):
        ...  # body omitted
I'm getting errors that don't make sense to me. The code above gives me this error:
        x = self.layer1(x, meta)
      File "/home/user/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
        result = self.forward(*input, **kwargs)
    TypeError: forward() takes 2 positional arguments but 3 were given
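A minimal reproduction of this first error, assuming layer1 is an nn.Sequential (torchvision's ResNet builds its layers with a _make_layer helper that returns nn.Sequential(*layers)). nn.Sequential.forward accepts a single input, so calling it with x and meta gives three positional arguments counting self. Block is a hypothetical stand-in for the modified BasicBlock:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    # Hypothetical stand-in for a BasicBlock whose forward requires meta.
    def forward(self, x, meta):
        return x

layer1 = nn.Sequential(Block(), Block())
x = torch.randn(2, 3)
meta = torch.tensor([0, 1])

caught = None
try:
    # nn.Sequential.forward takes only one input, so this raises TypeError
    # before any Block is ever called.
    layer1(x, meta)
except TypeError as err:
    caught = err
    print("TypeError:", err)
```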
But if I remove the last argument, meta, and run x = self.layer1(x) instead, I get this error:
    Traceback (most recent call last):
      File "main.py", line 202, in <module>
        main()
      File "main.py", line 153, in main
        outputs = model(inputs, meta)
      File "/home/user/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
        result = self.forward(*input, **kwargs)
      File "/…/models/resnet.py", line 471, in forward
        x = self.layer1(x)
      File "/home/user/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/user/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
        input = module(input)
      File "/home/user/anaconda3/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
        result = self.forward(*input, **kwargs)
    TypeError: forward() missing 1 required positional argument: 'meta'
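This second traceback passes through torch/nn/modules/container.py, where nn.Sequential calls each child as module(input) with a single argument, so the block never receives meta. One possible workaround, sketched under the assumption that layer1 is built from a list of blocks the way torchvision's _make_layer does it, is an nn.Sequential subclass whose forward hands meta to every child. MetaSequential and MetaBlock are hypothetical names, and adding the embedding per channel inside MetaBlock is only an illustration, not the original model's design:

```python
import torch
import torch.nn as nn

class MetaSequential(nn.Sequential):
    """Hypothetical nn.Sequential variant that passes meta to every child."""
    def forward(self, x, meta):
        for module in self:
            x = module(x, meta)
        return x

class MetaBlock(nn.Module):
    """Hypothetical BasicBlock-like module that mixes in a category embedding."""
    def __init__(self, channels, num_categories):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.emb = nn.Embedding(num_categories, channels)

    def forward(self, x, meta):
        out = self.bn(self.conv(x))
        # One embedding vector per sample, broadcast over height and width.
        out = out + self.emb(meta)[:, :, None, None]
        # Residual connection, as in a BasicBlock without downsampling.
        return self.relu(out + x)

layer1 = MetaSequential(MetaBlock(64, 10), MetaBlock(64, 10))
x = torch.randn(2, 64, 8, 8)
meta = torch.tensor([3, 7])
out = layer1(x, meta)
print(out.shape)  # torch.Size([2, 64, 8, 8])
```

With layer1 built this way, the existing call self.layer1(x, meta) in ResNet.forward would work unchanged; the layer-building helper would need to return MetaSequential(*layers) instead of nn.Sequential(*layers).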
This is the code my modification is based on: