Out of memory on GPU, need to save intermediate steps for links in network

Hi. I have trouble with a network structure that requires connections between multiple levels. I am running out of memory on my GPU, which has 8 GB. I was wondering if there is anything wrong with the way I am doing this. Below is the code for the network.

def _convs_with_relu(*convs):
    """Chain the given Conv3d modules, inserting an in-place ReLU after each."""
    layers = []
    for conv in convs:
        layers.append(conv)
        layers.append(nn.ReLU(True))
    return nn.Sequential(*layers)


class i2i(nn.Module):
    """Encoder-decoder 3D convolutional network with skip connections and
    four deeply-supervised score outputs, one per resolution level.

    ``forward(x)`` takes a 5D tensor (N, 1, D, H, W) — spatial sizes are
    assumed to be divisible by 8 so that the upsample-then-crop arithmetic
    lines up (TODO confirm against the caller) — and returns the tuple
    ``(s1, s2, s3, s4)`` of single-channel score maps, each brought back to
    the input's spatial size.
    """

    def __init__(self):
        super(i2i, self).__init__()

        # One shared 2x average-pooling layer, reused between encoder levels.
        self.pool = nn.AvgPool3d(2)

        # Encoder ("down") stacks; the digit in each name is the depth level.
        self.d1 = _convs_with_relu(
            nn.Conv3d(1, 32, 3, padding=1),
            nn.Conv3d(32, 32, 3, padding=1),
        )
        self.d2 = _convs_with_relu(
            nn.Conv3d(32, 128, 3, padding=1),
            nn.Conv3d(128, 128, 3, padding=1),
        )
        self.d3 = _convs_with_relu(
            nn.Conv3d(128, 256, 3, padding=1),
            nn.Conv3d(256, 256, 3, padding=1),
            nn.Conv3d(256, 256, 3, padding=1),
        )
        self.d4 = _convs_with_relu(
            nn.Conv3d(256, 512, 3, padding=1),
            nn.Conv3d(512, 512, 3, padding=1),
            nn.Conv3d(512, 512, 3, padding=1),
        )

        # Depthwise (groups == channels) transposed convs that 2x-upsample a
        # level's features before concatenation with the skip connection.
        self.merge3_up = nn.ConvTranspose3d(512, 512, 4, stride=2, groups=512)
        self.merge2_up = nn.ConvTranspose3d(256, 256, 4, stride=2, groups=256)
        self.merge1_up = nn.ConvTranspose3d(128, 128, 4, stride=2, groups=128)

        # Decoder ("up") stacks; the leading 1x1 conv fuses the concatenated
        # skip + upsampled channels back down to the level's channel width.
        self.u3 = _convs_with_relu(
            nn.Conv3d(768, 256, 1),
            nn.Conv3d(256, 256, 3, padding=1),
            nn.Conv3d(256, 256, 3, padding=1),
        )
        self.u2 = _convs_with_relu(
            nn.Conv3d(384, 128, 1),
            nn.Conv3d(128, 128, 3, padding=1),
            nn.Conv3d(128, 128, 3, padding=1),
        )
        self.u1 = _convs_with_relu(
            nn.Conv3d(160, 32, 1),
            nn.Conv3d(32, 32, 3, padding=1),
            nn.Conv3d(32, 32, 3, padding=1),
        )

        # Full-resolution score head.
        self.s1 = nn.Conv3d(32, 1, 1)

        # Score heads for the coarser levels: project to a single channel,
        # then transpose-conv back toward full resolution (2x / 4x / 8x).
        self.s2_up = nn.Sequential(
            nn.Conv3d(128, 1, 1),
            nn.ConvTranspose3d(1, 1, 4, stride=2, groups=1)
            )
        self.s3_up = nn.Sequential(
            nn.Conv3d(256, 1, 1),
            nn.ConvTranspose3d(1, 1, 8, stride=4, groups=1)
            )
        self.s4_up = nn.Sequential(
            nn.Conv3d(512, 1, 1),
            nn.ConvTranspose3d(1, 1, 16, stride=8, groups=1)
            )

    def forward(self, x):
        def crop(t, m):
            # Negative F.pad trims m voxels from every spatial border; this
            # removes the extra context the transposed convs produce so the
            # result aligns with the matching encoder feature map.
            return F.pad(t, (-m,) * 6)

        # Encoder: pool between levels, keep each level's output for skips.
        d1 = self.d1(x)
        d2 = self.d2(self.pool(d1))
        d3 = self.d3(self.pool(d2))
        d4 = self.d4(self.pool(d3))

        # Deepest score map, upsampled 8x and cropped to input size.
        s4 = crop(self.s4_up(d4), 4)

        # Decoder level 3: upsample d4, crop, concatenate with skip d3.
        u3 = self.u3(torch.cat((crop(self.merge3_up(d4), 1), d3), 1))
        s3 = crop(self.s3_up(u3), 2)

        # Decoder level 2.
        u2 = self.u2(torch.cat((crop(self.merge2_up(u3), 1), d2), 1))
        s2 = crop(self.s2_up(u2), 1)

        # Decoder level 1 (full resolution).
        u1 = self.u1(torch.cat((crop(self.merge1_up(u2), 1), d1), 1))
        s1 = self.s1(u1)

        return s1, s2, s3, s4

It seems that you have a large model. Recomputing intermediate activations during the backward pass — gradient checkpointing, as in `torch.utils.checkpoint` — may help you save some memory by trading compute for storage. Here is a great PR to that end, along with discussion and links to examples:

Best regards

Thomas

I will look into this. Thank you.