Torchsummary has a bug when used with tabular and image data

Hi all.

I’ve been working on a fusion model with two modalities: a CNN for images and an MLP for metadata.

I tried merging the two models using concatenation, the same way @ptrblck suggested, but torchsummary doesn’t work with it, even though the model trains and works fine.

I’d like to see the parameters of the merged model and confirm that it also works properly in torchsummary. I believe the issue is with torchsummary rather than with our configuration of the fusion model.

Here are my two modalities, followed by the ensemble model.

MLP modality


import pandas as pd
import numpy as np
import torch 
from torch import nn
from torch.nn import init, Parameter
from torchsummary import summary
import torch.nn.functional as F
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"




class MaxNet(nn.Module):
    def __init__(self, input_dim=3, omic_dim=32, dropout_rate=0.25):
        super(MaxNet, self).__init__()
        hidden = [64, 48, 32, 32]

        encoder1 = nn.Sequential(
            nn.Linear(input_dim, hidden[0]),
            nn.ELU(),
            nn.AlphaDropout(p=dropout_rate, inplace=False))
        
        encoder2 = nn.Sequential(
            nn.Linear(hidden[0], hidden[1]),
            nn.ELU(),
            nn.AlphaDropout(p=dropout_rate, inplace=False))
        
        encoder3 = nn.Sequential(
            nn.Linear(hidden[1], hidden[2]),
            nn.ELU(),
            nn.AlphaDropout(p=dropout_rate, inplace=False))

        encoder4 = nn.Sequential(
            nn.Linear(hidden[2], omic_dim),
            nn.ELU(),
            nn.AlphaDropout(p=dropout_rate, inplace=False))
        
        self.encoder = nn.Sequential(encoder1, encoder2, encoder3, encoder4)


    def forward(self, x):
        features = self.encoder(x)
        return features

def model_SNN() -> MaxNet:
    return MaxNet()

modelA = model_SNN()
summary(modelA, (3,), 1, 'cpu')
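
A quick forward pass with a dummy batch (a sanity-check snippet I’m adding here, not part of the training code) confirms the expected (batch, 32) feature shape:

x = torch.randn(2, 3)   # dummy metadata batch: 2 samples, 3 features
print(modelA(x).shape)  # torch.Size([2, 32])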

CNN modality

class VGG16(nn.Module):
    def __init__(self):
        super(VGG16, self).__init__()
        self.conv1_1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, padding=1)

        self.conv3_1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)

        self.conv4_1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)

        self.conv5_1 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1)

        self.maxpool = nn.AdaptiveAvgPool2d((1, 1))  # note: despite the name, this is adaptive *average* pooling down to 1x1


    def forward(self, x):
        x = F.relu(self.conv1_1(x))
        x = F.relu(self.conv1_2(x))
        x = self.maxpool(x)
        x = F.relu(self.conv2_1(x))
        x = F.relu(self.conv2_2(x))
        x = self.maxpool(x)
        x = F.relu(self.conv3_1(x))
        x = F.relu(self.conv3_2(x))
        x = F.relu(self.conv3_3(x))
        x = self.maxpool(x)
        x = F.relu(self.conv4_1(x))
        x = F.relu(self.conv4_2(x))
        x = F.relu(self.conv4_3(x))
        x = self.maxpool(x)
        x = F.relu(self.conv5_1(x))
        x = F.relu(self.conv5_2(x))
        x = F.relu(self.conv5_3(x))
        x = self.maxpool(x)
        x = x.reshape(x.shape[0], -1)  # flatten to (batch, 512)

        return x

def model_VGG() -> VGG16:
    return VGG16()

modelC = model_VGG()
summary(modelC, (3, 224, 224))
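
The same kind of dummy-input check (again just a sanity snippet) shows that the repeated adaptive pooling plus the final reshape leaves a (batch, 512) feature vector:

x = torch.randn(2, 3, 224, 224)  # dummy image batch
print(modelC(x).shape)           # torch.Size([2, 512])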

The ensemble model:


class MyEnsemble(nn.Module):
    def __init__(self, nb_classes=3):
        super(MyEnsemble, self).__init__()
        self.model_image =  modelC
        self.model_EHR = modelA
              
        # Classifier over the concatenated features:
        # 512 (image features) + 32 (metadata features) = 544
        self.layer_out = nn.Linear(544, nb_classes)

    
    def forward(self, x1, x3):
        x1 = self.model_image(x1)
        x3 = self.model_EHR(x3)
        x3 = x3.view(x3.size(0), -1)

        x = torch.cat((x1, x3), dim=1)
        x = self.layer_out(x)

        return x
    
    
def model_snn_vgg() -> MyEnsemble:
    return MyEnsemble()

model = model_snn_vgg()
print(model)
model.to(device=DEVICE, dtype=torch.float)
summary(model, [(3, 224, 224), (3,)])

This is the error message:

  Input In [23] in <cell line: 32>
    summary(model, [(3, 224, 224),(3,)])

  File ~\anaconda3\lib\site-packages\torchsummary\torchsummary.py:100 in summary
    total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))

  File <__array_function__ internals>:5 in prod

  File ~\anaconda3\lib\site-packages\numpy\core\fromnumeric.py:3051 in prod
    return _wrapreduction(a, np.multiply, 'prod', axis, dtype, out,

  File ~\anaconda3\lib\site-packages\numpy\core\fromnumeric.py:86 in _wrapreduction
    return ufunc.reduce(obj, axis, dtype, out, **passkwargs)

TypeError: can't multiply sequence by non-int of type 'tuple'
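
The failure comes from torchsummary calling np.prod over the whole input_size list; since the two shapes have different lengths ((3, 224, 224) vs. (3,)), NumPy can’t reduce the ragged sequence, hence the TypeError. The model itself runs fine, as this quick sanity check with random inputs shows (a minimal sketch, assuming the model is on DEVICE):

# Forward pass and a direct parameter count both work, which suggests the
# problem lies in torchsummary's input-size handling, not in the model.
x_img = torch.randn(2, 3, 224, 224, device=DEVICE)
x_ehr = torch.randn(2, 3, device=DEVICE)
print(model(x_img, x_ehr).shape)  # torch.Size([2, 3])

total = sum(p.numel() for p in model.parameters())
print(f"total parameters: {total:,}")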

Thanks for any suggestions.

The summary works fine if I use torchinfo and provide the batch dimension:

from torchinfo import summary

summary(model, [(1, 3, 224, 224), (1, 3)])

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
MyEnsemble                               [1, 3]                    --
├─VGG16: 1-1                             [1, 512]                  --
│    └─Conv2d: 2-1                       [1, 64, 224, 224]         1,792
│    └─Conv2d: 2-2                       [1, 64, 224, 224]         36,928
│    └─AdaptiveAvgPool2d: 2-3            [1, 64, 1, 1]             --
│    └─Conv2d: 2-4                       [1, 128, 1, 1]            73,856
│    └─Conv2d: 2-5                       [1, 128, 1, 1]            147,584
│    └─AdaptiveAvgPool2d: 2-6            [1, 128, 1, 1]            --
│    └─Conv2d: 2-7                       [1, 256, 1, 1]            295,168
│    └─Conv2d: 2-8                       [1, 256, 1, 1]            590,080
│    └─Conv2d: 2-9                       [1, 256, 1, 1]            590,080
│    └─AdaptiveAvgPool2d: 2-10           [1, 256, 1, 1]            --
│    └─Conv2d: 2-11                      [1, 512, 1, 1]            1,180,160
│    └─Conv2d: 2-12                      [1, 512, 1, 1]            2,359,808
│    └─Conv2d: 2-13                      [1, 512, 1, 1]            2,359,808
│    └─AdaptiveAvgPool2d: 2-14           [1, 512, 1, 1]            --
│    └─Conv2d: 2-15                      [1, 512, 1, 1]            2,359,808
│    └─Conv2d: 2-16                      [1, 512, 1, 1]            2,359,808
│    └─Conv2d: 2-17                      [1, 512, 1, 1]            2,359,808
│    └─AdaptiveAvgPool2d: 2-18           [1, 512, 1, 1]            --
├─MaxNet: 1-2                            [1, 32]                   --
│    └─Sequential: 2-19                  [1, 32]                   --
│    │    └─Sequential: 3-1              [1, 64]                   256
│    │    └─Sequential: 3-2              [1, 48]                   3,120
│    │    └─Sequential: 3-3              [1, 32]                   1,568
│    │    └─Sequential: 3-4              [1, 32]                   1,056
├─Linear: 1-3                            [1, 3]                    1,635
==========================================================================================
Total params: 14,722,323
Trainable params: 14,722,323
Non-trainable params: 0
Total mult-adds (G): 1.96
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 51.41
Params size (MB): 58.89
Estimated Total Size (MB): 110.91
==========================================================================================
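
If you’d rather not hard-code the shapes, torchinfo can also trace real tensors through its input_data argument (a small sketch, assuming model is already on DEVICE):

summary(model, input_data=[torch.randn(1, 3, 224, 224, device=DEVICE),
                           torch.randn(1, 3, device=DEVICE)])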

Thanks a lot for your suggestion. I didn’t know there was a torchinfo library.

I tried it and it worked!