Facing an issue with the forward method while combining two blocks

I have two model blocks: one is a pretrained ResNet model (resEncoder) and the other is a custom decoder model that I wrote. To initialize the decoder model, x.size() is taken as a constructor parameter called input_size.

Here, x is the output of the layer just before the avgpool layer of the ResNet model. So, in the forward pass, I want to create an object of the decoderModel() class and use it as self.decoder. But to do that, I need the size of the tensor that comes out before the avgpool layer of resnet18. One way is to hardcode it, which I don't want to do. I can call it like this instead, but will training work if I don't create self.decoder in __init__()?

class fullModel(nn.Module):
  def __init__(self, addDecoder: bool):
    super().__init__()
    self.addDecoder = addDecoder

    self.resEncoder = pretrainedModelBlock(model_name=params["model_name"], classes=params["num_classes"], addDecoder=self.addDecoder)

    # What I would like to do here, but beforeAvgPool does not exist yet:
    # if self.addDecoder:
    #   self.decoder = decoderModel(input_size=beforeAvgPool.size(), depth=5)

  def forward(self, x):
    if self.addDecoder:
      beforeAvgPool, pool, resnet_features, out = self.resEncoder(x)
      # The decoder is created here, inside forward, because only now
      # do I know the size of beforeAvgPool
      out = decoderModel(input_size=beforeAvgPool.size(), depth=5)(beforeAvgPool)
    else:
      out = self.resEncoder(x)

    return out

It runs this way, but since I do not have self.decoder in the __init__ of fullModel(), will training work correctly when I create a model object from this fullModel class?

x = torch.randn((64, 3, 32, 32))
fullModel(addDecoder=True)(x).size()
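
To illustrate my worry, here is a quick check (using the same classes as above):

model = fullModel(addDecoder=True)
# Because the decoder is only built inside forward(), its weights are
# never registered on fullModel: an optimizer constructed from
# model.parameters() would not update them, and a freshly initialized
# decoder is created on every forward call.
for name, _ in model.named_parameters():
  print(name)  # only resEncoder.* parameters appear, no decoder.*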

Hi! Do you expect the output size of your encoder to change on the fly? If not, you can try passing a random input to the encoder in the constructor itself to get beforeAvgPool.size() and use that to init your decoder. Though I’m really not sure why you are taking this approach :thinking:
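
Roughly like this. A minimal sketch, assuming your encoder accepts 3x32x32 inputs (as in your snippet) and that decoderModel only cares about the non-batch dimensions of input_size:

class fullModel(nn.Module):
  def __init__(self, addDecoder: bool):
    super().__init__()
    self.addDecoder = addDecoder
    self.resEncoder = pretrainedModelBlock(model_name=params["model_name"], classes=params["num_classes"], addDecoder=self.addDecoder)

    if self.addDecoder:
      # Run one dummy forward pass to discover the pre-avgpool size.
      # The batch dimension (1 here) is arbitrary; if decoderModel reads
      # the batch size from input_size, slice off dim 0 first.
      with torch.no_grad():
        beforeAvgPool, _, _, _ = self.resEncoder(torch.randn(1, 3, 32, 32))
      # The decoder is now created in __init__ and properly registered,
      # so model.parameters() includes its weights.
      self.decoder = decoderModel(input_size=beforeAvgPool.size(), depth=5)

  def forward(self, x):
    if self.addDecoder:
      beforeAvgPool, pool, resnet_features, out = self.resEncoder(x)
      return self.decoder(beforeAvgPool)
    return self.resEncoder(x)

Nothing here is tied to resnet18: the size is measured from whatever encoder you construct, not written out by hand.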

I can do that, actually, and that is what I meant by 'hardcoding it'. The thing is, I want to use the same classes whether I use resnet18, resnet50, or anything else, so if I fix the output size there, the code will not be as flexible.