I don't know why the reported forward/backward pass size for my model is so high: Forward/backward pass size (MB): 67652215240073.80
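The "Forward/backward pass size" line looks like the output of the torchsummary package; assuming that tool (the post does not name it, so this is an assumption), the figure would be produced by something like the sketch below, with Model as defined after it:

from torchsummary import summary

# torchsummary accepts a list of input sizes for multi-input models;
# it builds random tensors of those shapes and calls model(*inputs)
summary(model, [(3, 112, 112), (3, 224, 224)])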

import torch
import torch.nn as nn

class PurposeBlock(nn.Module):
    def __init__(self):
        super(PurposeBlock, self).__init__()
        self.conv1 = nn.Conv2d(3, 1, kernel_size=1)  # 1x1 conv: 3 input channels -> 1 output channel
        self.conv2 = nn.Conv2d(3, 1, kernel_size=1)  # 1x1 conv: 3 input channels -> 1 output channel
        
        self.Max = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Avg = nn.AvgPool2d(kernel_size=2, stride=2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        # Forward pass computation without re-initializing layers
        x = self.conv1(x)
        y = self.conv2(y)
        
        x_ = self.sigmoid(1 - x)
        y_ = self.sigmoid(1 - y)
        
        inverted_Max = self.Max(y_)
        inverted_Avg = self.Avg(y_)

        # Pooling halves y_'s spatial size, which matches x_ because y always has
        # twice x's spatial resolution; element-wise addition is used instead of matmul
        mul = (inverted_Max + inverted_Avg) + x_

        # Concatenate along channels (dim=1)
        # Concatenate along channels (dim=1): 1 + 1 -> 2 channels
        x = torch.cat((x, mul), dim=1)

        return x, y


class CNNBlock(nn.Module):
    def __init__(self):
        super(CNNBlock, self).__init__()
        self.conv1 = None
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv1_ = None
        self.conv2_ = nn.Conv2d(16, 32, kernel_size=3, padding=1)

        self.MaxPool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.MaxPool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
    def forward(self, x, y):
        input_ch1, input_ch2 = x.shape[1], y.shape[1]

        # NOTE: these convs are re-created with fresh random weights on every
        # forward pass, so they are effectively never trained
        self.conv1 = nn.Conv2d(input_ch1, 16, kernel_size=3, padding=1).to(x.device)
        self.conv1_ = nn.Conv2d(input_ch2, 16, kernel_size=3, padding=1).to(y.device)

        x = self.conv1(x)
        x = self.conv2(x)
        y = self.conv1_(y)
        y = self.conv2_(y)

        x_down = self.MaxPool1(x)
        y_down = self.MaxPool2(y)

        return x, x_down, y, y_down

class Bottleneck(nn.Module):
    def __init__(self):
        super(Bottleneck, self).__init__()
        self.conv = nn.Conv2d(32, 32, kernel_size=2, stride=2)

    def forward(self, x, y):
        y = self.conv(y)  # stride-2 conv halves y's spatial size to match x
        x = x + y
        return x

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.MaxPool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.MaxPool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.PB = PurposeBlock()
        self.CB = CNNBlock()

        self.conv1 = nn.Conv2d(32, 1, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 1, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=3, padding=1)
        self.BN = Bottleneck()

        self.MP = nn.MaxPool2d(kernel_size=1, stride=2)
        self.concat1 = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        #self.GT = GlobalAttention(196)
        #self.ST = SpatialAttention(32)
        #self.ET = EdgeAttention(64)

        self.deconv3 = nn.ConvTranspose2d(in_channels=32, out_channels=3, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(in_channels=3, out_channels=32, kernel_size=4, stride=2, padding=1)
        self.deconv1 = nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1)

        self.last = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, padding=1)

    def forward(self, x, y):

        # Encoder
        x = self.MaxPool1(x)
        y = self.MaxPool2(y)

        x, y = self.PB(x, y)
        x1, x_down, y1, y_down = self.CB(x, y)

        x_down = self.conv3(x_down)
        y_down = self.conv3(y_down)
        x, y = self.PB(x_down, y_down)
        x2, x_down, y2, y_down = self.CB(x, y)

        x_down = self.conv3(x_down)
        y_down = self.conv3(y_down)
        x, y = self.PB(x_down, y_down)
        x3, x_down, y3, y_down = self.CB(x, y)  # x3: (1, 32, 14, 14), y3: (1, 32, 28, 28) for the inputs below

        #x_down = self.conv1(x_down)
        #y_down = self.conv2(y_down)

        # Bottleneck
        NN = self.BN(x_down, y_down)
        print(NN.shape)  # debug print: torch.Size([1, 32, 7, 7])

        ### Middle Concatenation

        # Global attention
        y3 = self.MP(y3)
        x3 = torch.cat((x3, y3), dim=1)
        x3 = self.concat1(x3)
        #x3 = self.GT(x3)                    # (1, 3, 14, 14)

        # Spatial attention
        y2 = self.MP(y2)
        x2 = x2 + y2
        #x2 = self.ST(x2)                    # (1, 32, 28, 28)

        # Edge attention
        y1 = self.MP(y1)
        x1 = torch.cat((x1, y1), dim=1)
        #x1 = self.ET(x1)                    # (1, 64, 56, 56)

        # Decoder
        up3 = self.deconv3(NN) + x3         # (1, 3, 14, 14)
        up2 = self.deconv2(up3) + x2        # (1, 32, 28, 28)
        up1 = self.deconv1(up2) + x1        # (1, 64, 56, 56)

        last = self.last(up1)
        return last
        
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model().to(device)
input1 = torch.randn(1, 3, 112, 112).to(device)
input2 = torch.randn(1, 3, 224, 224).to(device)

last = model(input1, input2)
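Note that CNNBlock re-creates conv1 and conv1_ with fresh random weights inside forward, so they are effectively never trained, and the pattern can confuse summary tools. A minimal sketch of the same block written with nn.LazyConv2d, which infers in_channels on the first call (one possible cleanup, not necessarily the intended design):

import torch.nn as nn

class CNNBlockLazy(nn.Module):
    def __init__(self):
        super().__init__()
        # LazyConv2d infers in_channels from the first input it sees, so each
        # layer is materialized once and its weights persist across calls
        self.conv1 = nn.LazyConv2d(16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv1_ = nn.LazyConv2d(16, kernel_size=3, padding=1)
        self.conv2_ = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.MaxPool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.MaxPool2 = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x, y):
        x = self.conv2(self.conv1(x))
        y = self.conv2_(self.conv1_(y))
        return x, self.MaxPool1(x), y, self.MaxPool2(y)

Unlike the original, this version keeps one trainable set of weights shared across the three CNNBlock calls instead of drawing new random weights on every forward pass.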

I cannot reproduce any issues and see a memory usage of ~9MB after the forward pass:

last = model(input1,input2)
print(torch.cuda.memory_allocated()/1024**2)
# 8.96533203125

last.mean().backward()
print(torch.cuda.memory_allocated()/1024**2)
# 1.16748046875
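The allocated memory drops after backward because intermediate activations are freed as gradients are computed. To capture the peak usage across forward and backward instead of the current allocation, torch.cuda.max_memory_allocated can be used; a quick sketch:

torch.cuda.reset_peak_memory_stats()
last = model(input1, input2)
last.mean().backward()
# peak memory (MB) held at any point during forward + backward
print(torch.cuda.max_memory_allocated() / 1024**2)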