Can't visualise model structure when defining an architecture in this way

Hi there!

I have decided to formulate my model using the class below so that I can insert and test new modules (CBAM, DANet, etc.) in a MobileNet backbone, and also register hooks on the gradients to see where the model's attention is focused. However, when I instantiate the model and try to print it, none of the layers show up. Any ideas would be much appreciated!

import torch
import torch.nn as nn
from torchvision.models import mobilenet_v3_small, mobilenet_v3_large
from torchvision.models import MobileNet_V3_Small_Weights as small_weights, MobileNet_V3_Large_Weights as large_weights

class MBNV3_Grad(nn.Module):
    def __init__(self, model, module, output_size):
        super().__init__()

        self.model = model
        self.module = module
        self.output_size = output_size
        self.layers = [1, 4, 5, 6, 7, 8, 9, 10, 11] if self.model == "small" else [4, 5, 6, 11, 12, 13, 14, 15]

    # Initialize the weights of the inserted module layers or the SE block
    def _weights_init(self, m):
        torch.manual_seed(42)
        if isinstance(m, nn.Conv2d):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.01)

    def set_grad(self):
        """Freeze the whole backbone, then enable gradients only for the inserted
        module layers (or the SE block) and the final classifier layer."""
        # set all grads in the model to False
        for param in self.mbnv3.parameters():
            param.requires_grad = False

        # set the grad of the last classifier layer to True
        for param in self.mbnv3.classifier[-1].parameters():
            param.requires_grad = True

        if self.model == "small":
            for param in self.mbnv3.features[1].block[1].parameters():
                param.requires_grad = True

            for i in self.layers[1:]:
                for param in self.mbnv3.features[i].block[2].parameters():
                    param.requires_grad = True
        else:
            for i in self.layers:
                for param in self.mbnv3.features[i].block[2].parameters():
                    param.requires_grad = True

        return self.mbnv3

    def create_custom_model(self):
        """Create a custom model by inserting the custom module into the chosen backbone layers."""
        if self.model == "small":
            # set the base model
            self.mbnv3 = mobilenet_v3_small(weights=small_weights.IMAGENET1K_V1)
            # insert the custom module in specific layers of the small variant
            if self.layers[0] == 1:
                prev_out_channels = self.mbnv3.features[1].block[0].out_channels
                if self.module is not None:
                    self.mbnv3.features[1].block[1] = self.module(prev_out_channels)
                self.mbnv3.features[1].block[1].apply(self._weights_init)

            for i in self.layers[1:]:
                prev_out_channels = self.mbnv3.features[i].block[0].out_channels
                if self.module is not None:
                    self.mbnv3.features[i].block[2] = self.module(prev_out_channels)
                self.mbnv3.features[i].block[2].apply(self._weights_init)
        else:
            # insert the custom module in specific layers of the large variant
            self.mbnv3 = mobilenet_v3_large(weights=large_weights.DEFAULT)
            for i in self.layers:
                prev_out_channels = self.mbnv3.features[i].block[0].out_channels
                if self.module is not None:
                    self.mbnv3.features[i].block[2] = self.module(prev_out_channels)
                self.mbnv3.features[i].block[2].apply(self._weights_init)

        self.mbnv3.classifier[-1] = nn.Linear(self.mbnv3.classifier[-1].in_features, self.output_size)
        self.mbnv3.classifier[-1].apply(self._weights_init)

        self.set_grad()

        return self.mbnv3

    def restructure_model(self):
        # dissect the network to access its last convolutional layer
        self.features_conv = self.mbnv3.features[:13] if self.model == "small" else self.mbnv3.features[:17]
        # adaptive average pooling
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=1)
        # get the classifier of the model
        self.classifier = self.mbnv3.classifier
        # placeholder for the gradients
        self.gradients = None

        # delete the original model
        del self.mbnv3

        return self.features_conv, self.avgpool, self.classifier

    # hook for the gradients of the activations
    def activations_hook(self, grad):
        self.gradients = grad

    def forward(self, x):
        self.create_custom_model()  # create the custom model
        self.restructure_model()    # restructure the model

        x = self.features_conv(x)   # extract the features

        # register the hook on the last feature map so its gradients are captured
        h = x.register_hook(self.activations_hook)

        x = self.avgpool(x)         # adaptively average pool the features
        x = x.squeeze()             # flatten the output of the adaptive average pooling
        x = self.classifier(x)      # get the class scores from the classifier

        return x

    # method for the gradient extraction
    def get_activations_gradient(self):
        return self.gradients

    # method for the activation extraction
    def get_activations(self, x):
        return self.features_conv(x)
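
In case it helps, this is roughly how I plan to use the hooks for the Grad-CAM style visualisation (the class index and the normalisation at the end are just placeholders):

model = MBNV3_Grad(model="small", module=None, output_size=10)

img = torch.randn(1, 3, 224, 224)   # stand-in for a preprocessed image
scores = model(img)                 # forward pass also registers the hook
scores[0].backward()                # backprop the score of class 0 (placeholder choice)

gradients = model.get_activations_gradient()        # gradients at the last conv feature map
pooled_grads = gradients.mean(dim=[0, 2, 3])        # channel-wise importance weights
activations = model.get_activations(img).detach()   # the feature maps themselves

# weight each channel by its importance and collapse into a 2D heatmap
heatmap = (activations * pooled_grads[None, :, None, None]).mean(dim=1).squeeze()
heatmap = torch.relu(heatmap)       # keep positive influence only
heatmap = heatmap / heatmap.max()   # normalise to [0, 1]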

You are not initializing the internal modules in the __init__ method, as would be the common case, but in the forward method instead. You will thus only see all modules after the initial forward pass, as seen here:

model = MBNV3_Grad(model="small", module=None, output_size=1)
print(model)
# MBNV3_Grad()

model(torch.randn(1, 3, 224, 224))
print(model)
# MBNV3_Grad(
#   (features_conv): Sequential(
#     (0): Conv2dNormActivation(
#       (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
#       (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
#       (2): Hardswish()
#     )
# ...
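
A minimal fix, keeping the rest of your class unchanged, would be to build the submodules once at construction time instead of on every call. A sketch, reusing your create_custom_model and restructure_model methods as written:

class MBNV3_Grad(nn.Module):
    def __init__(self, model, module, output_size):
        super().__init__()
        self.model = model
        self.module = module
        self.output_size = output_size
        self.layers = [1, 4, 5, 6, 7, 8, 9, 10, 11] if model == "small" else [4, 5, 6, 11, 12, 13, 14, 15]

        # build everything once so the submodules are registered immediately
        self.create_custom_model()
        self.restructure_model()

    def forward(self, x):
        x = self.features_conv(x)
        h = x.register_hook(self.activations_hook)  # capture gradients of the last feature map
        x = self.avgpool(x)
        x = x.squeeze()
        x = self.classifier(x)
        return x

print(model) will then show the full structure right after instantiation, and you also avoid rebuilding the backbone and re-loading the pretrained weights on every forward pass.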

Ahhh of course, thanks so much for your help!