Can't access individual Modules/layers of a Sequential network

Hey :wave: ,

I created a Unet with a custom encoder using segmentation-models-pytorch. That custom encoder consists of Modules of its own (which were moved over from another network). The result looks something like this:

Unet(
  (encoder): Comma_Encoder()
  (decoder): UnetDecoder(
    (center): Identity()
    (blocks): ModuleList(
      (0): DecoderBlock(
        (conv1): Conv2dReLU(
          (0): Conv2d(131, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (attention1): Attention(
          (attention): Identity()
        )
        (conv2): Conv2dReLU(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (attention2): Attention(
          (attention): Identity()
        )
      )
      (1): DecoderBlock(
        (conv1): Conv2dReLU(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (attention1): Attention(
          (attention): Identity()
        )
        (conv2): Conv2dReLU(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (attention2): Attention(
          (attention): Identity()
        )
      )
    )
  )
  (segmentation_head): SegmentationHead(
    (0): Conv2d(32, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ConvTranspose2d(256, 6, kernel_size=(4, 4), stride=(4, 4))
    (2): Activation(
      (activation): Identity()
    )
  )
)

However, if I try to access the layers of the encoder through model.encoder, I get nothing :frowning:

next(model.encoder.named_modules())
# ('', Comma_Encoder())

[_ for _ in model.encoder.children()]
# []

Which is pretty weird, since I can run a forward pass through it easily and train it correctly too :person_shrugging: which pretty much implies that I just can’t access its submodules.

This may be something with the Unet class, but it’s inherited from nn.Module too AFAIK (demonstrated here) - and I was able to access the layers of the encoder perfectly before plugging it into SMP…

Is there some sort of limit to the recurrence of some Module? I wanted to utilize accessing individual layers for skip connections which are vital for segmentation :disappointed:

Any help is appreciated… :heart:

I guess Comma_Encoder just doesn’t contain any submodules.
Could you post the class definition here as I cannot find it in the linked repository?

sure - the encoder is extracted from an encoder-decoder style architecture (specifically, VQ-VAE-2) as demonstrated in this repo.

It’s just 20 such ResBlocks in each of 2 ‘submodules’ (encoder_1 and encoder_2), so 40 blocks connected in series. Everything inherits from Module, so I would expect them to be accessible :person_shrugging:

class Encoder(nn.Module):
    def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
        super().__init__()

        if stride == 4:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel, channel, 3, padding=1),
            ]

        elif stride == 2:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 3, padding=1),
            ]

        for i in range(n_res_block):
            blocks.append(ResBlock(channel, n_res_channel))

        blocks.append(nn.ReLU(inplace=True))

        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        return self.blocks(input)

^ Just basically 2 of these in series make up Comma_Encoder.

Additionally, I am able to view Comma_Encoder's module when I print it out (before plugging into the Unet) and get exactly what I expected.

I think I may have a clue here - this is how I extract the first few layers from my base model to construct the Encoder:

VQVAE = import_mod.VQVAE(**self.base_args).to(self.device)
VQVAE.load_state_dict(self.state_dict)

newmodel = torch.nn.Sequential(*(list(VQVAE.children())))

return [*newmodel[:2].to(self.device), torch.nn.Conv2d(128, 3, 2, 2).to(self.device)]

It just occurred to me - is the problem caused by instantiating Sequential? And in that case, is there any workaround that lets me access individual modules while still performing surgery on the model…
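For reference, wrapping modules in nn.Sequential does register them, so that step on its own should not hide anything. Below is a minimal sketch, assuming a hypothetical TinyBase stand-in for the VQ-VAE (the class, its attribute names, and the channel sizes are invented for illustration and are not taken from the linked repo):

```python
import torch
import torch.nn as nn


class TinyBase(nn.Module):
    """Hypothetical stand-in for the pretrained base model."""

    def __init__(self):
        super().__init__()
        self.enc_b = nn.Sequential(nn.Conv2d(3, 8, 4, stride=2, padding=1), nn.ReLU())
        self.enc_t = nn.Sequential(nn.Conv2d(8, 16, 4, stride=2, padding=1), nn.ReLU())
        self.head = nn.Conv2d(16, 3, 1)


base = TinyBase()

# "Surgery": keep the first two children and append a new conv layer.
stages = list(base.children())[:2]
trunk = nn.Sequential(*stages, nn.Conv2d(16, 3, 2, 2))

# The children stay reachable by index and by name.
child_names = [name for name, _ in trunk.named_children()]
print(child_names)

# The sliced modules are the same objects as in the base model (weights are shared).
shares_weights = trunk[0] is base.enc_b
print(shares_weights)

out = trunk(torch.zeros(1, 3, 32, 32))
print(tuple(out.shape))
```

If the submodules disappear later, the likelier culprit is how the result is stored on the parent module rather than the Sequential itself.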

Thanks a ton for your immediate response :rocket: Pytorch forums is really a better place with all its active contributors so willing to help :slight_smile:

Thanks for the update.
The issue would come from returning a plain Python list containing the modules, which won’t register them properly in the parent module.
Use nn.ModuleList instead:

return nn.ModuleList([*newmodel[:2].to(self.device), torch.nn.Conv2d(128, 3, 2, 2).to(self.device)])
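To make the difference concrete, here is a small sketch comparing a plain list with nn.ModuleList. ListEncoder and ModuleListEncoder are hypothetical toy classes, not the actual Comma_Encoder, and the forward pass that collects per-stage outputs is only one possible way to expose features for skip connections:

```python
import torch
import torch.nn as nn


class ListEncoder(nn.Module):
    """Hypothetical encoder that keeps its stages in a plain Python list."""

    def __init__(self):
        super().__init__()
        # A plain list is an ordinary attribute: nothing gets registered.
        self.stages = [nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()]


class ModuleListEncoder(nn.Module):
    """Same stages, but wrapped in nn.ModuleList so they are registered."""

    def __init__(self):
        super().__init__()
        self.stages = nn.ModuleList([nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()])

    def forward(self, x):
        # Collect the output of every stage, e.g. for skip connections.
        features = []
        for stage in self.stages:
            x = stage(x)
            features.append(x)
        return features


plain = ListEncoder()
registered = ModuleListEncoder()

plain_children = list(plain.children())
plain_param_count = len(list(plain.parameters()))
registered_names = [name for name, _ in registered.named_modules()]
registered_param_count = len(list(registered.parameters()))

print(plain_children, plain_param_count)
print(registered_names, registered_param_count)

features = registered(torch.zeros(1, 3, 16, 16))
print([tuple(f.shape) for f in features])
```

Note that the plain-list version also hides its parameters from parameters(), so an optimizer built from model.parameters() would not update them, and .to(device) or state_dict() on the parent would skip them.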

That still doesn’t seem to fix it, unfortunately. When iterating over the encoder’s named_modules() I still get this:

[('', Comma_Encoder())]

which is weirdly nested - as soon as I iterate named_modules in Comma_Encoder() I get the same thing again :man_shrugging: weird… :thinking:

Could you post a minimal, executable code snippet to reproduce the issue, please?

Sorry for the late reply, had gone for a vacay…

I was messing around to start some debugging, and apparently if I simply initialize the base model in __init__, everything works as intended:

        ....
        self.VQVAE = import_mod.VQVAE(**self.base_args).to(self.device)

    def get_stages(self):
        self.VQVAE.load_state_dict(self.state_dict)
        # Use the instance attributes; bare VQVAE / newmodel are undefined in this scope
        self.newmodel = torch.nn.Sequential(*(list(self.VQVAE.children())))

        return torch.nn.ModuleList([*self.newmodel[:2].to(self.device), torch.nn.Conv2d(128, 3, 2, 2).to(self.device)])

before, everything was under get_stages() to keep things clean.

I’ve no idea how that works, but perhaps you could hazard a guess why such behavior happens for anyone else having trouble?

Both approaches work for me:

class MyModel1(nn.Module):
    def __init__(self):
        super().__init__()
        self.module1 = nn.Linear(1, 1)
        
class MyModel2(nn.Module):
    def __init__(self):
        super().__init__()
        self.get_stages()
        
    def get_stages(self):
        self.module2 = nn.Linear(1, 1)
        
model1 = MyModel1()
print(dict(model1.named_modules()))
# {'': MyModel1(
#   (module1): Linear(in_features=1, out_features=1, bias=True)
# ), 'module1': Linear(in_features=1, out_features=1, bias=True)}

model2 = MyModel2()
print(dict(model2.named_modules()))
# {'': MyModel2(
#   (module2): Linear(in_features=1, out_features=1, bias=True)
# ), 'module2': Linear(in_features=1, out_features=1, bias=True)}

so I still don’t know what’s causing the issue and can’t add any warnings for it :wink:


It’s probably due to the convoluted mess of modules I’ve made :sweat_smile: Thanks a lot for your help! <3