Can somebody help me, I wrote code by myself for some segmentation task, but while this code the GPU memory gets filled up after 3-4 epochs. I am not getting what’s the issue.
code is here
class Backbone(torch.nn.Module):
def __init__(self):
super(Backbone,self).__init__()
# self.feature_list=None
#download the model (resnet50)
#take out the calculated features from selected laye
#the feature size will be 32x32, 16x16, 8x8
self.model = torchvision.models.resnet50(pretrained=True)
for param in self.model.parameters():
param.requires_grad = False
def feature_hook(self,feature_list):
def hook(module, input, output):
feature_list.append(output)
# print(output)
return hook
def forward(self,x):
feature_list=[]
#hooks
self.model.layer2.register_forward_hook(self.feature_hook(feature_list))
self.model.layer3.register_forward_hook(self.feature_hook(feature_list))
self.model.layer4.register_forward_hook(self.feature_hook(feature_list))
self.model(x)
return feature_list
class ColumnDecoder(nn.Module):
def __init__(self,in_channle=512,out_channel=1,up_conv_in_channel = 1280):
super(ColumnDecoder,self).__init__()
self.col_cov_7 = nn.Sequential(nn.Conv2d(in_channle,in_channle, kernel_size=1),
nn.ReLU(inplace=True),
nn.Dropout(0.8),
nn.Conv2d(in_channle, in_channle, kernel_size=1),
nn.ReLU(inplace=True))
self.col_up_conv = nn.ConvTranspose2d(up_conv_in_channel, out_channel, kernel_size=2, stride=2, dilation=1)
def forward(self, x, feature_32, feature_16):
out = self.col_cov_7(x)
out = F.interpolate(out,scale_factor=2)
out = torch.cat([out, feature_16], dim=1)
out = F.interpolate(out,scale_factor=2)
out = torch.cat([out, feature_32], dim=1)
out = F.interpolate(out,scale_factor=2)
out = F.interpolate(out,scale_factor=2)
return self.col_up_conv(out)
class TableDecoder(nn.Module):
def __init__(self,in_channle=512,out_channel=1,up_conv_in_channel = 1280):
super(TableDecoder,self).__init__()
self.tab_cov_7 = nn.Sequential(nn.Conv2d(in_channle,in_channle, kernel_size=1),
nn.ReLU(inplace=True))
self.tab_up_conv = nn.ConvTranspose2d(up_conv_in_channel, out_channel, kernel_size=2, stride=2, dilation=1)
def forward(self, x, feature_32, feature_16):
out = self.tab_cov_7(x)
out = F.interpolate(out,scale_factor=2)
out = torch.cat([out, feature_16], dim=1)
out = F.interpolate(out,scale_factor=2)
out = torch.cat([out, feature_32], dim=1)
out = F.interpolate(out,scale_factor=2)
out = F.interpolate(out,scale_factor=2)
return self.tab_up_conv(out)
class TableNet(torch.nn.Module):
def __init__(self,input_size=(3,256,256)):
super(TableNet,self).__init__()
#getting instance of backbone
self.base_model = Backbone()
#get the output feature size and channel
test_in = torch.Tensor([np.random.rand(*input_size)])
feature_32, feature_16, feature_8 = self.base_model(test_in)
print(feature_32.shape,feature_16.shape,feature_8.shape)
feature_8_out_channel = feature_8.shape[1]
feature_16_out_channel = feature_16.shape[1]
feature_32_out_channel = feature_32.shape[1]
self.mid_cov_5_6 = nn.Sequential(nn.Conv2d(feature_8_out_channel, feature_8_out_channel, kernel_size=1),
nn.ReLU(inplace=True),
nn.Dropout(0.8),
nn.Conv2d(feature_8_out_channel, feature_8_out_channel, kernel_size=1),
nn.ReLU(inplace=True),
nn.Dropout(0.8))
#get the output feature size and channel
up_conv_in_channel = feature_8_out_channel+feature_16_out_channel+feature_32_out_channel
self.column_decoder = ColumnDecoder(in_channle=feature_8_out_channel,up_conv_in_channel=up_conv_in_channel)
self.table_decoder = TableDecoder(in_channle=feature_8_out_channel,up_conv_in_channel=up_conv_in_channel)
def forward(self,x):
feature_32, feature_16, feature_8 = self.base_model(x)
feature_8_out = self.mid_cov_5_6(feature_8)
col_out = self.column_decoder(feature_8_out,feature_32,feature_16)
tab_out = self.table_decoder(feature_8_out,feature_32,feature_16)
return tab_out,col_out