Hi, I am using my own custom std function for some reason.
I used to leave keepdim=False as the default, and everything worked just fine. However, I find that if keepdim=True, the GPU memory usage keeps going up every iteration and finally explodes.
The function looks like this:
```python
import functools
import torch

class CustomStd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dim=None, keepdim=False, eps=1e-5, unbiased=True):
        if dim is None:
            dim = tuple(range(input.dim()))
        dev = input - input.mean(dim=dim, keepdim=True)
        ctx.save_for_backward(input)
        ctx.eps = eps
        ctx.dev = dev
        ctx.numdim = input.dim()
        ctx.numel = functools.reduce(lambda x, y: x * y, [input.size(d) for d in dim])
        if unbiased:
            ctx.numel -= 1
        ctx.std = torch.sqrt(torch.sum(dev * dev, dim=dim, keepdim=True) / ctx.numel)
        return ctx.std if keepdim else ctx.std.squeeze()

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output
        for i in range(grad_output.dim(), ctx.numdim):
            grad_input = grad_input.unsqueeze(i)
        grad_input = ctx.dev * (ctx.numel - 1) / (ctx.numel**2) / (ctx.std + ctx.eps) * grad_input
        return grad_input, None, None, None, None
```
I assumed that torch.squeeze() creates a view onto the tensor, so the memory usage should be the same either way.
Am I missing anything here?
Thanks in advance.
In your code, if keepdim=True, then ctx.std is the output. And ctx.save_for_backward() must be used for input/output tensors, otherwise you will get exactly this kind of memory leak.
So you should just make sure not to save the output in ctx.foo and it will go away.
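Here is the general pattern in a minimal, standalone form (using torch.exp purely as an illustration, since its backward needs the output):

```python
import torch

class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        out = torch.exp(x)
        # The output must go through save_for_backward, never ctx.<attr>,
        # otherwise the output <-> graph reference cycle keeps memory alive.
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        out, = ctx.saved_tensors
        # d/dx exp(x) = exp(x), which is exactly the saved output
        return out * grad_output

x = torch.randn(4, requires_grad=True)
y = Exp.apply(x)
y.sum().backward()
print(torch.allclose(x.grad, torch.exp(x.detach())))  # True
```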
But I do want to save the std in ctx, since it will be used in the backward pass.
In this case, should I call ctx.save_for_backward(input, ctx.std) just before returning from the forward pass?
In this case, you should not save it on ctx at all and only pass it via save_for_backward:
I couldn’t run the code so there might be typos. But this is the idea:
```python
import functools
import torch

class CustomStd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dim=None, keepdim=False, eps=1e-5, unbiased=True):
        if dim is None:
            dim = tuple(range(input.dim()))
        dev = input - input.mean(dim=dim, keepdim=True)
        ctx.eps = eps
        ctx.dev = dev
        ctx.numdim = input.dim()
        ctx.numel = functools.reduce(lambda x, y: x * y, [input.size(d) for d in dim])
        if unbiased:
            ctx.numel -= 1
        std = torch.sqrt(torch.sum(dev * dev, dim=dim, keepdim=True) / ctx.numel)
        ctx.std_shape = std.shape
        res = std if keepdim else std.squeeze()
        # Save the output (and input) with save_for_backward, not on ctx
        ctx.save_for_backward(input, res)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        input, res = ctx.saved_tensors
        std = res.view(ctx.std_shape)  # restore the keepdim=True shape
        grad_input = grad_output
        for i in range(grad_output.dim(), ctx.numdim):
            grad_input = grad_input.unsqueeze(i)
        grad_input = ctx.dev * (ctx.numel - 1) / (ctx.numel**2) / (std + ctx.eps) * grad_input
        return grad_input, None, None, None, None
```
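For reference, the forward pass of the fixed version can be sanity-checked against torch.std like this (the class is repeated so the snippet runs on its own; the backward math is kept exactly as in the original post):

```python
import functools
import torch

class CustomStd(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dim=None, keepdim=False, eps=1e-5, unbiased=True):
        if dim is None:
            dim = tuple(range(input.dim()))
        dev = input - input.mean(dim=dim, keepdim=True)
        ctx.eps = eps
        ctx.dev = dev
        ctx.numdim = input.dim()
        ctx.numel = functools.reduce(lambda x, y: x * y, [input.size(d) for d in dim])
        if unbiased:
            ctx.numel -= 1
        std = torch.sqrt(torch.sum(dev * dev, dim=dim, keepdim=True) / ctx.numel)
        ctx.std_shape = std.shape
        res = std if keepdim else std.squeeze()
        # Output saved via save_for_backward, not on ctx
        ctx.save_for_backward(input, res)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        input, res = ctx.saved_tensors
        std = res.view(ctx.std_shape)
        grad_input = grad_output
        for i in range(grad_output.dim(), ctx.numdim):
            grad_input = grad_input.unsqueeze(i)
        grad_input = ctx.dev * (ctx.numel - 1) / (ctx.numel**2) / (std + ctx.eps) * grad_input
        return grad_input, None, None, None, None

x = torch.randn(8, 16)
out = CustomStd.apply(x, (1,), True)  # dim=(1,), keepdim=True
# Forward matches the built-in unbiased std
print(torch.allclose(out, x.std(dim=1, keepdim=True)))  # True
```

On a GPU run, torch.cuda.memory_allocated() should now stay flat across iterations, since the output is no longer stashed on ctx.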