How can I use inner module outputs without increasing model size

Hi there. I’ll keep my explanation short, since I have a lot of code to present.

I am trying to develop a new autoencoder architecture, but I am having trouble managing the forward/backward pass size of my model. I’ll explain with the following test code:

Imagine we have the following two modules. One returns an extra inner output in addition to the final output, while the other returns only the final output.

import torch
import torch.nn as nn

class testNet1(nn.Module):
    def __init__(self, dim_in=3, dim_inner=5, dim_out=3):
        super().__init__()
        self.tube1 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube3 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube4 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube2 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)

        self.conv_out = nn.Conv2d(dim_inner*4, dim_out, 3, 1, 1)
    
    def forward(self, x):
        x1 = self.tube1(x)
        x2 = self.tube2(x)
        x3 = self.tube3(x)
        x4 = self.tube4(x)

        x = torch.cat((x1, x2, x3, x4), dim=1)
        x = self.conv_out(x)
        # return the final output together with the first inner activation
        return x, x1
    

class testNet2(nn.Module):
    def __init__(self, dim_in=3, dim_inner=5, dim_out=3):
        super().__init__()
        self.tube1 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube3 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube4 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube2 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)

        self.conv_out = nn.Conv2d(dim_inner*4, dim_out, 3, 1, 1)
    
    def forward(self, x):
        x1 = self.tube1(x)
        x2 = self.tube2(x)
        x3 = self.tube3(x)
        x4 = self.tube4(x)

        x = torch.cat((x1, x2, x3, x4), dim=1)
        x = self.conv_out(x)
        return x

These two modules have exactly the same number of parameters and exactly the same forward/backward pass size (0.04 MB), which you can check for yourself using the torchsummary library.
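
You can also do a quick sanity check on the parameter count in plain PyTorch, e.g.:

n1 = sum(p.numel() for p in testNet1().parameters())
n2 = sum(p.numel() for p in testNet2().parameters())
print(n1, n2)  # both print 1103 with the default dims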

Now, I made these two models, one using testNet1 and the other using testNet2. In these examples I haven’t used the inner outputs, but according to my tests, using them or not doesn’t really affect the F/B pass size.

class testModule1(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1_1 = testNet1()
        self.net1_2 = testNet1()
        self.net1_3 = testNet1()

    def forward(self, x):
        x, y = self.net1_1(x)
        x, z = self.net1_2(x)
        x, w = self.net1_3(x)
        return x
    

class testModule2(nn.Module):
    def __init__(self):
        super().__init__()
        self.net2_1 = testNet2()
        self.net2_2 = testNet2()
        self.net2_3 = testNet2()

    def forward(self, x):
        x = self.net2_1(x)
        x = self.net2_2(x)
        x = self.net2_3(x)
        return x

Now these two models have very different F/B pass sizes. For an arbitrary input size of (3, 64, 64), the module with inner outputs (testModule1) has a F/B pass size of 5.7 GB, while the second one (testModule2) has a size of 2.44 MB.
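
For reference, I get these numbers from torchsummary, i.e. from calls along the lines of:

import torch
from torchsummary import summary

device = "cuda" if torch.cuda.is_available() else "cpu"

# summary() prints the per-layer output shapes plus an estimated
# forward/backward pass size for the whole model
summary(testModule1().to(device), (3, 64, 64), device=device)  # this is the 5.7 GB case
summary(testModule2().to(device), (3, 64, 64), device=device)  # this is the 2.44 MB case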

I have tried detaching the inner output before returning it, in the hope of somehow improving the size situation, but that obviously didn’t help at all.
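
That is, a variant of testNet1.forward along these lines:

    def forward(self, x):
        x1 = self.tube1(x)
        x2 = self.tube2(x)
        x3 = self.tube3(x)
        x4 = self.tube4(x)

        x = torch.cat((x1, x2, x3, x4), dim=1)
        x = self.conv_out(x)
        # detach the inner output so it carries no autograd history
        return x, x1.detach()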

I want to know how it would be possible to use the inner outputs of a small module in a bigger network without increasing the model size so drastically. I would appreciate any help. Thank you.

I don’t know what these sizes mean, but I assume you have calculated the memory requirements during the forward/backward pass. If so, these might be expected, since you are explicitly storing y, z, and w in the first model, which would need to be kept in memory.
Did you check how large these tensors are and do they match the memory calculation you have performed?
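
For the (3, 64, 64) input you mentioned, each inner output would be a (5, 64, 64) float32 tensor, and its footprint can be computed directly, e.g.:

import torch

y = torch.randn(5, 64, 64)  # shape of one inner output for a (3, 64, 64) input
print(y.nelement() * y.element_size() / 1024**2)
# 0.078125 -> roughly 0.08 MB per tensor, so about 0.23 MB for y, z, and w together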

Well, these tensors have shape (5, 64, 64), and I don’t really know what exactly I should be calculating, but 5.7 GB does not look right in any shape or form, especially since these values are already created inside their respective modules.

I have tried implementing everything in a single class like this:

class testModule3(nn.Module):
    def __init__(self, dim_in, dim_inner, dim_out):
        super().__init__()
        self.tube1_1 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube1_2 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube1_3 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube1_4 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)

        self.conv_out1 = nn.Conv2d(dim_inner*4, dim_out, 3, 1, 1)

        self.tube2_1 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube2_2 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube2_3 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube2_4 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)

        self.conv_out2 = nn.Conv2d(dim_inner*4, dim_out, 3, 1, 1)
        
        self.tube3_1 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube3_2 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube3_3 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube3_4 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)

        self.conv_out3 = nn.Conv2d(dim_inner*4, dim_out, 3, 1, 1)

    def forward(self, x):
        x1_1 = self.tube1_1(x)
        x1_2 = self.tube1_2(x)
        x1_3 = self.tube1_3(x)
        x1_4 = self.tube1_4(x)

        y = x1_1

        x = torch.cat((x1_1, x1_2, x1_3, x1_4), dim=1)
        x = self.conv_out1(x)
        
        x2_1 = self.tube2_1(x)
        x2_2 = self.tube2_2(x)
        x2_3 = self.tube2_3(x)
        x2_4 = self.tube2_4(x)

        w = x2_1

        x = torch.cat((x2_1, x2_2, x2_3, x2_4), dim=1)
        x = self.conv_out2(x)
        
        x3_1 = self.tube3_1(x)
        x3_2 = self.tube3_2(x)
        x3_3 = self.tube3_3(x)
        x3_4 = self.tube3_4(x)

        z = x3_1

        x = torch.cat((x3_1, x3_2, x3_3, x3_4), dim=1)
        x = self.conv_out3(x)
        return x

And this works just fine. The F/B pass size of this class is 2.16 MB, which is even lower than before.
In this module I have the y, w and z variables with exactly the same values, and I can use them however I choose. But when I implement my model in a modular way, the model size explodes.

I want to know why this is happening and how I can fix it.

You would need to explain what these sizes mean and how they were calculated.
Checking the max. allocated memory matches my understanding, as seen here:

import torch
import torch.nn as nn

class testNet1(nn.Module):
    def __init__(self, dim_in=3, dim_inner=5, dim_out=3):
        super().__init__()
        self.tube1 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube3 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube4 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube2 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)

        self.conv_out = nn.Conv2d(dim_inner*4, dim_out, 3, 1, 1)
    
    def forward(self, x):
        x1 = self.tube1(x)
        x2 = self.tube2(x)
        x3 = self.tube3(x)
        x4 = self.tube4(x)

        x = torch.cat((x1, x2, x3, x4), dim=1)
        x = self.conv_out(x)
        return x, x1
    

class testNet2(nn.Module):
    def __init__(self, dim_in=3, dim_inner=5, dim_out=3):
        super().__init__()
        self.tube1 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube3 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube4 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)
        self.tube2 = nn.Conv2d(dim_in, dim_inner, 3, 1, 1)

        self.conv_out = nn.Conv2d(dim_inner*4, dim_out, 3, 1, 1)
    
    def forward(self, x):
        x1 = self.tube1(x)
        x2 = self.tube2(x)
        x3 = self.tube3(x)
        x4 = self.tube4(x)

        x = torch.cat((x1, x2, x3, x4), dim=1)
        x = self.conv_out(x)
        return x

    
print(torch.cuda.max_memory_allocated()/1024**2)
# 0.0
device = "cuda"
h, w = 224, 224

net1 = testNet1().to(device)
x = torch.randn(1, 3, h, w, device=device)
out = net1(x)
print(torch.cuda.max_memory_allocated()/1024**2)
# 9.10302734375

del net1
del x
del out
torch.cuda.reset_max_memory_allocated()
print(torch.cuda.max_memory_allocated()/1024**2)
# 0.0

net2 = testNet2().to(device)
x = torch.randn(1, 3, h, w, device=device)
out = net2(x)
print(torch.cuda.max_memory_allocated()/1024**2)
# 9.10302734375

del net2
del x
del out
torch.cuda.reset_max_memory_allocated()
print(torch.cuda.max_memory_allocated()/1024**2)
# 0.0

class testModule1(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1_1 = testNet1()
        self.net1_2 = testNet1()
        self.net1_3 = testNet1()

    def forward(self, x):
        x, y = self.net1_1(x)
        x, z = self.net1_2(x)
        x, w = self.net1_3(x)
        intermediate_bytes = y.nelement() * y.element_size() + z.nelement() * z.element_size() + w.nelement() * w.element_size()
        print("storing ", intermediate_bytes / 1024**2, " additional MB")
        return x
    

class testModule2(nn.Module):
    def __init__(self):
        super().__init__()
        self.net2_1 = testNet2()
        self.net2_2 = testNet2()
        self.net2_3 = testNet2()

    def forward(self, x):
        x = self.net2_1(x)
        x = self.net2_2(x)
        x = self.net2_3(x)
        return x

mod1 = testModule1().to(device)
x = torch.randn(1, 3, h, w, device=device)
out = mod1(x)
# storing  2.87109375  additional MB
print(torch.cuda.max_memory_allocated()/1024**2)
# 19.83935546875

del mod1
del x
del out
torch.cuda.reset_max_memory_allocated()
print(torch.cuda.max_memory_allocated()/1024**2)
# 0.0

mod2 = testModule2().to(device)
x = torch.randn(1, 3, h, w, device=device)
out = mod2(x)
print(torch.cuda.max_memory_allocated()/1024**2)
# 17.92529296875