Hi, I am trying to run my code under “torch.no_grad” and in “eval” mode. I want to save memory by reassigning the same variable, as in x = conv(x). But when I run the following code:
x = self.conv1(x)
x = self.conv2(x)
I found that it consumes 3 times as much memory as x, so it seems the first x is not being replaced. Ideally this could be reduced to 2 times (since each conv operation needs extra memory for its output). I have finished training, so how can I free the memory of the first x after ops like x = conv(x)?
thx!
I doubt you can save more memory than is already saved by wrapping the code in a with torch.no_grad() block. The output of conv1 has to be allocated at one place during the forward pass.
How many output channels are you using?
Sorry, I didn’t make myself clear. The code is like this:
def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    return x
After conv1 has been computed, the memory of the very first x should be freed, since it is not needed anymore. But PyTorch does not free it. Is there any way to free the input x after conv1?
If x is not needed anymore, PyTorch will free it and reuse the memory.
In case you doubt it, you could try to assign the result of conv1 to another variable and delete x manually.
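The suggestion above can be sketched roughly as follows. This is only an illustration under assumptions: the input shape, the layer sizes, and the `report` helper are made up for the example, and the memory readout is only printed when a CUDA device is available. Note that `del x` can only release the tensor if no other reference to it exists.

```python
import torch
import torch.nn as nn

# Assumed layer sizes, chosen only for this example.
device = "cuda" if torch.cuda.is_available() else "cpu"
conv1 = nn.Conv2d(64, 64, 3).to(device).eval()
conv2 = nn.Conv2d(64, 64, 3).to(device).eval()


def report(tag):
    # Hypothetical helper: prints currently allocated GPU memory, if any.
    if torch.cuda.is_available():
        mib = torch.cuda.memory_allocated() / 2**20
        print(f"{tag}: {mib:.2f} MiB allocated")


with torch.no_grad():
    x = torch.randn(1, 64, 32, 32, device=device)  # assumed input shape
    report("after creating x")

    out = conv1(x)   # store the result under a different name
    report("after conv1 (x and out both alive)")

    del x            # drop the only reference to the input
    report("after del x")

    out = conv2(out)
    report("after conv2")
```

If the number printed after `del x` drops by roughly the size of the input, the allocator has released that tensor for reuse.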
Actually, I tested it on a ResNet-like architecture.
The whole code looks like this:
import torch.nn as nn

class BasicBlock(nn.Module):
    def __init__(self):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(64, 64, 3)
        self.conv2 = nn.Conv2d(64, 64, 3)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)  # feed the output of conv1, not x
        return out

class WholeModel(nn.Module):
    def __init__(self):
        super(WholeModel, self).__init__()
        self.block1 = BasicBlock()
        self.block2 = BasicBlock()

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        return out
While running “out = self.block1(x)”, it keeps x alive. After the first conv in block1, “out = self.conv1(x)”, x is not needed anymore, but the input x is not freed until “out = self.block1(x)” has finished. I think it could be freed earlier, since it is no longer used once conv1 in block1 is done. I tried “del x” after “out = self.conv1(x)”, but it has no effect. Do you have any ideas about how to free the memory right after “out = self.conv1(x)”? Thank you.
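One likely reason the “del x” inside the block has no effect is ordinary Python reference semantics: “del x” removes only the local name, while the caller (here WholeModel.forward, and whatever called the model) still holds its own reference to the same object. The sketch below illustrates this in plain Python; `FakeTensor` and `block_forward` are hypothetical stand-ins, not PyTorch code, and the immediate release it shows relies on CPython's reference counting.

```python
import weakref


class FakeTensor:
    """Hypothetical stand-in for a large tensor."""


def block_forward(x):
    out = FakeTensor()  # pretend this is self.conv1(x)
    del x               # removes only this function's local name
    return out


x = FakeTensor()
ref = weakref.ref(x)    # lets us check whether the object still exists

out = block_forward(x)
alive_after_inner_del = ref() is not None  # caller still references x

del x                   # drop the caller's reference as well
alive_after_outer_del = ref() is not None  # no references remain

print(alive_after_inner_del, alive_after_outer_del)
```

By the same logic, the input can only be released early if every outer frame also drops its reference before the inner call finishes.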