import torch
import torch.nn as nn
class PurposeBlock(nn.Module):
    def __init__(self):
        super(PurposeBlock, self).__init__()
        self.conv1 = nn.Conv2d(3, 1, kernel_size=1)  # Initializes conv1 with input channels = 3
        self.conv2 = nn.Conv2d(3, 1, kernel_size=1)  # Initializes conv2 with input channels = 3
        self.Max = nn.MaxPool2d(kernel_size=2, stride=2)
        self.Avg = nn.AvgPool2d(kernel_size=2, stride=2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        # Forward pass computation without re-initializing layers
        x = self.conv1(x)
        y = self.conv2(y)
        x_ = self.sigmoid(1 - x)
        y_ = self.sigmoid(1 - y)
        inverted_Max = self.Max(y_)
        inverted_Avg = self.Avg(y_)
        # Use element-wise addition instead of matmul for better efficiency
        mul = (inverted_Max + inverted_Avg) + x_
        # Concatenate along channels (dim=1)
        x = torch.cat((x, mul), dim=1)
        return x, y
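For reference, a quick shape check (a sketch assuming the same 2x spatial relation between the two inputs that the model below uses) shows that the first output picks up an extra channel from the concatenation, which is why the next block cannot hard-code its input channels:

pb = PurposeBlock()
out_x, out_y = pb(torch.randn(1, 3, 56, 56), torch.randn(1, 3, 112, 112))
print(out_x.shape, out_y.shape)
# expected: torch.Size([1, 2, 56, 56]) torch.Size([1, 1, 112, 112])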
class CNNBlock(nn.Module):
    def __init__(self):
        super(CNNBlock, self).__init__()
        # conv1 and conv1_ are created in forward() because their input
        # channel count depends on the incoming tensors
        self.conv1 = None
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv1_ = None
        self.conv2_ = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.MaxPool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.MaxPool2 = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x, y):
        input_ch1, input_ch2 = x.shape[1], y.shape[1]
        # Note: these two layers are re-created (and freshly initialized) on
        # every forward call, so an optimizer created beforehand will not
        # update their weights
        self.conv1 = nn.Conv2d(input_ch1, 16, kernel_size=3, padding=1).to(x.device)
        self.conv1_ = nn.Conv2d(input_ch2, 16, kernel_size=3, padding=1).to(x.device)
        x = self.conv1(x)
        x = self.conv2(x)
        y = self.conv1_(y)
        y = self.conv2_(y)
        x_down = self.MaxPool1(x)
        y_down = self.MaxPool2(y)
        return x, x_down, y, y_down
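If the goal is only to infer the input channels from the data, a minimal sketch (assuming PyTorch >= 1.8, where nn.LazyConv2d is available; CNNBlockLazy is a hypothetical name for illustration) keeps the layers registered once instead of re-creating them on every call:

class CNNBlockLazy(nn.Module):
    # Sketch: same layout as CNNBlock, but the first convs infer their
    # input channels on the first forward pass and are created only once
    def __init__(self):
        super().__init__()
        self.conv1 = nn.LazyConv2d(16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv1_ = nn.LazyConv2d(16, kernel_size=3, padding=1)
        self.conv2_ = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.MaxPool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.MaxPool2 = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x, y):
        x = self.conv2(self.conv1(x))
        y = self.conv2_(self.conv1_(y))
        return x, self.MaxPool1(x), y, self.MaxPool2(y)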
class Bottleneck(nn.Module):
    def __init__(self):
        super(Bottleneck, self).__init__()
        self.conv = nn.Conv2d(32, 32, kernel_size=2, stride=2)

    def forward(self, x, y):
        y = self.conv(y)
        x = x + y
        return x
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.MaxPool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.MaxPool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.PB = PurposeBlock()
        self.CB = CNNBlock()
        self.conv1 = nn.Conv2d(32, 1, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 1, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=3, padding=1)
        self.BN = Bottleneck()
        self.MP = nn.MaxPool2d(kernel_size=1, stride=2)
        self.concat1 = nn.Conv2d(64, 3, kernel_size=3, padding=1)
        #self.GT = GlobalAttention(196)
        #self.ST = SpatialAttention(32)
        #self.ET = EdgeAttention(64)
        self.deconv3 = nn.ConvTranspose2d(in_channels=32, out_channels=3, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(in_channels=3, out_channels=32, kernel_size=4, stride=2, padding=1)
        self.deconv1 = nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.last = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, padding=1)

    def forward(self, x, y):
        # Encoder
        x = self.MaxPool1(x)
        y = self.MaxPool2(y)
        x, y = self.PB(x, y)
        x1, x_down, y1, y_down = self.CB(x, y)
        x_down = self.conv3(x_down)
        y_down = self.conv3(y_down)
        x, y = self.PB(x_down, y_down)
        x2, x_down, y2, y_down = self.CB(x, y)
        x_down = self.conv3(x_down)
        y_down = self.conv3(y_down)
        x, y = self.PB(x_down, y_down)
        x3, x_down, y3, y_down = self.CB(x, y)  # x->(28,28,1) and y->(56,56,1)
        #x_down = self.conv1(x_down)
        #y_down = self.conv2(y_down)
        # Bottleneck
        NN = self.BN(x_down, y_down)
        print(NN.shape)
        ### Middle Concatenation
        # Global Attention
        y3 = self.MP(y3)
        x3 = torch.cat((x3, y3), dim=1)
        x3 = self.concat1(x3)
        #x3 = self.GT(x3) #(3, 28, 28)
        # Spatial Attention
        y2 = self.MP(y2)
        x2 = x2 + y2
        #x2 = self.ST(x2) #(32, 56, 56)
        # Edge Attention
        y1 = self.MP(y1)
        x1 = torch.cat((x1, y1), dim=1)
        #x1 = self.ET(x1) #(64, 112, 112)
        up3 = self.deconv3(NN) + x3  #(3, 28, 28)
        up2 = self.deconv2(up3) + x2  #(32, 56, 56)
        up1 = self.deconv1(up2) + x1  #(64, 112, 112)
        last = self.last(up1)
        return last
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model().to(device)
input1 = torch.randn(1, 3, 112, 112).to(device)
input2 = torch.randn(1, 3, 224, 224).to(device)
last = model(input1, input2)
I cannot reproduce any issues and see a memory usage of ~9MB:
last = model(input1,input2)
print(torch.cuda.memory_allocated()/1024**2)
# 8.96533203125
last.mean().backward()
print(torch.cuda.memory_allocated()/1024**2)
# 1.16748046875