Using keepdim=True in my own std function causes a GPU memory leak

Hi, I am using my own custom std function for some reason.
I used to leave keepdim=False as the default, and everything worked just fine. However, I find that if keepdim=True, the GPU memory usage just keeps going up every iteration and finally explodes.

The function looks like this:

class CustomStd(torch.autograd.Function):
    """Standard deviation over `dim` with a hand-written backward pass.

    NOTE(review): this version leaks GPU memory when keepdim=True. In that
    case `ctx.std` (line below the reduction) is the tensor that is *returned*
    as the Function's output, and an output saved as a plain ctx attribute
    instead of via `ctx.save_for_backward` keeps the autograd graph alive —
    this is the leak discussed in this thread.
    """
    @staticmethod
    def forward(ctx, input, dim=None, keepdim=False, eps=1e-5, unbiased=True):
        # No dim given: reduce over every dimension of the input.
        if dim is None:
            dim = tuple([i for i in range(input.dim())])
        # Deviation from the mean (kept broadcastable via keepdim=True).
        dev = input - input.mean(dim=dim, keepdim=True)
        ctx.save_for_backward(input)
        ctx.eps=eps
        # dev is an intermediate (not the output), so a ctx attribute is OK here.
        ctx.dev = dev
        ctx.numdim = input.dim()
        # Count of reduced elements; minus one below for Bessel's correction.
        ctx.numel = functools.reduce(lambda x, y: x * y, [input.size(d) for d in dim])
        if unbiased:
            ctx.numel -= 1
        # BUG(review): when keepdim=True this tensor is returned unchanged, so
        # the Function's output is being stashed on ctx — the source of the leak.
        ctx.std = torch.sqrt(torch.sum(dev * dev, dim=dim, keepdim=True) / ctx.numel)
        return ctx.std if keepdim else ctx.std.squeeze()

    @staticmethod
    def backward(ctx, grad_output):
        input,= ctx.saved_tensors
        grad_input = grad_output
        # Re-insert the reduced dims so the gradient broadcasts against dev.
        for i in range(grad_output.dim(), ctx.numdim):
            grad_input = grad_input.unsqueeze(i)
        grad_input = ctx.dev * (ctx.numel - 1) / (ctx.numel**2) / (ctx.std + ctx.eps) * grad_input
        return grad_input, None, None, None, None

I assume that torch.squeeze() creates a view onto the tensor, and therefore the memory usage should be the same.
Am I missing anything here?
Thanks in advance.

Hi,

In your code, if keepdim=True, then ctx.std is the output. And ctx.save_for_backward() must be used for inputs/outputs, otherwise you will get such a memory leak :confused:
So you should just make sure to not save the output in ctx.foo and it will go away :slight_smile:

But I do want to save the std in ctx since it will be used in backward pass.
In this case, should I use ctx.save_for_backward(input, ctx.std) just before returning in forward pass?

Hi,

In this case, you should not save it on ctx at all, only via save_for_backward:
I couldn’t run the code so there might be typos. But this is the idea:

class CustomStd(torch.autograd.Function):
    """Standard deviation over `dim` with a hand-written backward pass.

    The (possibly squeezed) result is stored via ``ctx.save_for_backward``
    rather than as a plain ``ctx`` attribute: a tensor that is also the
    Function's *output* must go through ``save_for_backward``, otherwise
    the autograd graph is kept alive and GPU memory leaks.
    """

    @staticmethod
    def forward(ctx, input, dim=None, keepdim=False, eps=1e-5, unbiased=True):
        # No dim given: reduce over every dimension of the input.
        if dim is None:
            dim = tuple(range(input.dim()))
        # Deviation from the mean; an intermediate tensor (not the output),
        # so keeping it as a ctx attribute does not cause the leak.
        dev = input - input.mean(dim=dim, keepdim=True)
        ctx.eps = eps
        ctx.dev = dev
        ctx.numdim = input.dim()
        # Count of reduced elements; minus one below for Bessel's correction.
        ctx.numel = functools.reduce(lambda x, y: x * y, [input.size(d) for d in dim])
        if unbiased:
            ctx.numel -= 1
        std = torch.sqrt(torch.sum(dev * dev, dim=dim, keepdim=True) / ctx.numel)
        # Remember the keepdim shape so backward can undo the squeeze below.
        ctx.std_shape = std.shape
        res = std if keepdim else std.squeeze()
        ctx.save_for_backward(input, res)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        input, res = ctx.saved_tensors
        # Fix: the saved shape lives on ctx — bare `std_shape` was a NameError.
        std = res.view(ctx.std_shape)
        grad_input = grad_output
        # Re-insert the reduced dims so the gradient broadcasts against dev.
        for i in range(grad_output.dim(), ctx.numdim):
            grad_input = grad_input.unsqueeze(i)
        grad_input = ctx.dev * (ctx.numel - 1) / (ctx.numel**2) / (std + ctx.eps) * grad_input
        return grad_input, None, None, None, None