When I am using save_tensor, my memory leaks

Hi, when I am using save_tensor (torch.autograd.graph.saved_tensors_hooks), I have a layer that runs forward but never runs backward, and its memory gets leaked. I cannot use .detach() or .cpu(). Do you have any function to solve this problem?

Could you post a minimal, executable code snippet reproducing this issue, so that we could debug it, please?

Thanks for your help. You can run the following code.

from time import sleep
import torch
import torch.nn as nn

def get_conv_bn(conv=None, bn=None, act=None):
    
    def do_conv_bn(x):
        def pack(y):
            # Flag saved tensors that are the convolution output, so that
            # unpack can recompute them from x instead of keeping the stored copy.
            if y.grad_fn is not None and y.grad_fn.next_functions[0][0].name() == "CudnnConvolutionBackward0":
                return x, y, True
            else:
                return x, y, False
        def unpack(pack_return):
            x, y, do_conv = pack_return
            if do_conv:
                # Recompute the convolution output during backward.
                with torch.no_grad():
                    return conv(x)
            else:
                return y
                
        y = conv(x)
        with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
            y = act(bn(y))
        return y
    return do_conv_bn   

class ConvBnRelu(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, use_save_tensor):
        conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        bn = nn.BatchNorm2d(num_features=out_channels)
        relu = nn.ReLU(inplace=True)
        self.module_list = [conv, bn, relu]
        super(ConvBnRelu, self).__init__(*self.module_list)
        self.use_save_tensor = use_save_tensor
        if self.use_save_tensor:
            self.save_tensor_fn = get_conv_bn(conv, bn, relu)
        else:
            self.save_tensor_fn = None

    def forward(self, x):
        if self.use_save_tensor:
            out = self.save_tensor_fn(x)
        else:
            out = super().forward(x)
        return out 

class Backbone(nn.Module):
    def __init__(self, num_stacks=1, use_save_tensor=True):
        super().__init__()
        self.num_stacks = num_stacks
        self.stack_blocks = nn.ModuleList()
        for i in range(num_stacks):
            self.stack_blocks.append(ConvBnRelu(16, 16, 3, 2, 1, use_save_tensor))
        self.do_not_backward_block = ConvBnRelu(16, 16, 3, 2, 1, use_save_tensor)

    def forward(self, x):
        for i in range(self.num_stacks):
            x = self.stack_blocks[i](x)
        # This layer runs forward but never runs backward, which leaks GPU memory.
        do_not_backward_x = self.do_not_backward_block(x)
        return x, do_not_backward_x 

def cal_loss(output, target):
    x, do_not_backward_x  = output
    return (x-target).sum()
    # return (x-target).sum() + (do_not_backward_x-target).sum()

if __name__=="__main__":
    model = Backbone(use_save_tensor=True).cuda()
    input = torch.randn([2, 16, 512, 512], requires_grad=True).cuda()
    for i in range(100):
        print(f"input memory {torch.cuda.memory_allocated()}")
        output = model(input)
        print(f"model memory {torch.cuda.memory_allocated()}")
        loss = cal_loss(output, 1)
        loss.backward()
        print(f"backward memory {torch.cuda.memory_allocated()}")
        sleep(2)

I think the trouble definitely comes from

with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = act(bn(y))

Every time you run the network forward, saved_tensors_hooks pushes the saved tensors of the computational graph somewhere on the device (in this case, the GPU).
Then you need to pop them by calling either torch.Tensor.backward() or torch.autograd.grad().

Consequently, in the cal_loss function you forgot to put do_not_backward_x into your backward "shopping cart", so the computational graph of do_not_backward_x never pops what was pushed for it by torch.autograd.graph.saved_tensors_hooks.

You have probably already seen how the behavior changes when using return (x-target).sum() + (do_not_backward_x-target).sum() instead of return (x-target).sum().
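
For clarity, here is a minimal sketch (separate from your model) of that lifecycle: whatever the pack hook returns is held by the autograd graph, and it is only released once that graph is either backwarded or dropped.

import torch

n_packed = 0

def pack(t):
    # Whatever pack returns is stored in the graph in place of the tensor.
    global n_packed
    n_packed += 1
    return t

def unpack(obj):
    # Called during backward; must give autograd back a tensor.
    return obj

a = torch.randn(4, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    used = (a * a).sum()      # this graph will be backwarded
    unused = (a * a).sum()    # this graph is never backwarded

used.backward()               # unpack runs and the objects packed for `used` are freed
# The objects packed for `unused` stay referenced by its graph until that graph
# itself goes away, e.g. after `del unused` or when `unused` is reassigned.
del unused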

Thanks for your reply. My point of view is the same as yours. I can use (x + do_not_backward_x * 0 - target).sum() instead of (x - target).sum(), but this kind of change is not very elegant: it introduces a lot of hard-coded workarounds, and save_tensor can no longer be used as a drop-in plugin. I hope the problem can be solved inside save_tensor's code or in autograd itself.
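
For reference, the only generic form of that workaround I can think of is to collect every non-backwarded output and attach it to the loss with a zero weight, roughly like this (just a sketch; attach_unused is a hypothetical helper, not an existing API):

import torch

def attach_unused(loss, unused_outputs):
    # Adding 0 * t.sum() keeps every unused graph reachable from `loss`, so
    # backward() also walks (and releases) their saved tensors; the extra
    # gradient contribution is zero.
    for t in unused_outputs:
        loss = loss + 0.0 * t.sum()
    return loss

# toy usage with stand-in tensors
a = torch.randn(8, requires_grad=True)
x = (a * 2).sum()
do_not_backward_x = (a * 3).sum()
loss = attach_unused(x - 1, [do_not_backward_x])
loss.backward()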

Hope this could help:

from time import sleep
import torch
import torch.nn as nn

def get_conv_bn(conv=None, bn=None, act=None):

    def do_conv_bn(x):
        def pack(y):
            if y.grad_fn is not None and y.grad_fn.next_functions[0][0].name() == "CudnnConvolutionBackward0":
                return x, y, True
            else:
                return x, y, False
        def unpack(pack_return):
            x, y, do_conv = pack_return
            if do_conv:
                with torch.no_grad():
                    return conv(x)
            else:
                return y

        y = conv(x)
        with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
            y = act(bn(y))
        return y
    return do_conv_bn

class ConvBnRelu(nn.Sequential):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        bn = nn.BatchNorm2d(num_features=out_channels)
        relu = nn.ReLU(inplace=True)
        self.module_list = [conv, bn, relu]
        super(ConvBnRelu, self).__init__(*self.module_list)
        self.save_tensor_fn = get_conv_bn(conv, bn, relu)

    def forward(self, x, use_save_tensor=True):
        if use_save_tensor:
            out = self.save_tensor_fn(x)
        else:
            out = super().forward(x)
        return out

class Backbone(nn.Module):
    def __init__(self, num_stacks=1):
        super().__init__()
        self.num_stacks = num_stacks
        self.stack_blocks = nn.ModuleList()
        for i in range(num_stacks):
            self.stack_blocks.append(ConvBnRelu(16, 16, 3, 2, 1))
        self.do_not_backward_block = ConvBnRelu(16, 16, 3, 2, 1)

    def forward(self, x):
        for i in range(self.num_stacks):
            x = self.stack_blocks[i](x, use_save_tensor=True)
        # This layer runs forward but never runs backward; call it without the
        # saved-tensors hooks so it does not leak GPU memory.
        do_not_backward_x = self.do_not_backward_block(x, use_save_tensor=False)
        return x, do_not_backward_x

def cal_loss(output, target):
    x, do_not_backward_x  = output
    return (x-target).sum()
    #return (x-target).sum() + (do_not_backward_x-target).sum()

if __name__=="__main__":
    model = Backbone().cuda()
    input = torch.randn([2, 16, 512, 512], requires_grad=True).cuda()
    for i in range(100):
        print(f"input memory {torch.cuda.memory_allocated()}")
        output = model(input)
        print(f"model memory {torch.cuda.memory_allocated()}")
        loss = cal_loss(output, 1)
        loss.backward()
        print(f"backward memory {torch.cuda.memory_allocated()}")
        sleep(2)

Thanks for your reply. Your code is correct, but my example is only a minimal executable reproduction. In my real usage scenario the backbone and neck have many conv-bn-relu blocks, and I cannot manually turn save_tensor off for every block that does not run backward.
I still have a question: when a block uses the normal forward, a block that does not run backward does not leak memory; but when a block uses the save-tensor forward, it does leak memory.
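
One way to narrow this down might be to check whether it is the graph of do_not_backward_x that keeps the memory alive, by dropping the reference and forcing collection inside the loop (a diagnostic sketch only, reusing Backbone and cal_loss from my first snippet):

import gc
from time import sleep
import torch

model = Backbone(use_save_tensor=True).cuda()
input = torch.randn([2, 16, 512, 512]).cuda()
for i in range(10):
    output = model(input)
    loss = cal_loss(output, 1)
    loss.backward()
    print(f"after backward  {torch.cuda.memory_allocated()}")
    del output, loss          # drop the only references to the unused graph
    gc.collect()              # collect any Python-level reference cycles
    print(f"after del + gc  {torch.cuda.memory_allocated()}")
    sleep(1)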