In-place ReLU doesn't work for custom autograd functions

When we use nn.ReLU(inplace=True), it should reduce the GPU memory needed for back-propagation. This is true for convolutional neural networks.

However, the in-place mode doesn't work for custom functions.
Specifically, I built a custom autograd function and a custom module based on it.
When I combine this module with in-place ReLU, nn.ReLU(inplace=True) doesn't seem to save any memory for back-propagation. Here is an example:

net_1 = nn.Sequential()
net_2 = nn.Sequential()
for i in range(50):
    net_1.add_module("conv_{}".format(i), MyConv2d(32, 32, 3))
    net_2.add_module("conv_{}".format(i), MyConv2d(32, 32, 3))
    net_1.add_module("relu_{}".format(i), nn.ReLU(True))
    net_2.add_module("relu_{}".format(i), nn.ReLU(False))

Then I measured torch.cuda.max_memory_allocated() for net_1 and net_2 given the same input,
and found that both use the same amount of memory:
net_1: 470MB
net_2: 470MB
How can I modify my custom function to make nn.ReLU(inplace=True) work?

Did you try the same code with a built-in CNN instead of MyConv2d? Does that report a significant difference in max_memory_allocated?

@gphilip
If I replace MyConv2d with the built-in convolution layer, in-place mode saves half of the memory:

net_1: 240 MB
net_2: 470 MB


Can someone point out the reason for this?

Could you post an executable code snippet to reproduce the issue, please?

@gphilip
@ptrblck
This is the executable code:

import torch
from torch import nn

class conv(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = 2 * x
        ctx.save_for_backward(y)
        return torch.sin(y)
    
    @staticmethod
    def backward(ctx, grad_out):
        (y,) = ctx.saved_tensors  # saved_tensors is a tuple
        with torch.no_grad():
            if ctx.needs_input_grad[0]:
                # d/dx sin(2x) = 2 * cos(2x); y = 2x was saved in forward
                return 2 * grad_out * torch.cos(y)
            else:
                return None
    

class MyConv(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.parameter.Parameter(torch.randn(1, 1, 64, 64))
        self.main = conv.apply

    def forward(self, x):
        return self.main(x) * self.weight


def evaluate_memory(module):
    x = torch.randn(32, 32, 64, 64).cuda()
    torch.cuda.reset_peak_memory_stats()
    y = module(x)
    return torch.cuda.max_memory_allocated() / (1024 ** 2)

net_1 = nn.Sequential()
net_2 = nn.Sequential()

use_default_conv = False

for i in range(50):
    if use_default_conv:
        net_1.add_module("conv_{}".format(i), nn.Conv2d(32, 32, 3, padding=1))
        net_2.add_module("conv_{}".format(i), nn.Conv2d(32, 32, 3, padding=1))
    else:
        net_1.add_module("conv_{}".format(i), MyConv())
        net_2.add_module("conv_{}".format(i), MyConv())
    net_1.add_module("relu_{}".format(i), nn.ReLU(True))
    net_2.add_module("relu_{}".format(i), nn.ReLU(False))

net_1.cuda()
net_2.cuda()

m_1 = evaluate_memory(net_1)
m_2 = evaluate_memory(net_2)
print("memory for in-place mode: {:.1f}".format(m_1))
print("memory for out-place mode: {:.1f}".format(m_2))

With use_default_conv=True, the output is:

memory for in-place mode: 819.6
memory for out-place mode: 1619.6

With use_default_conv=False, the output is:

memory for in-place mode: 2401.6
memory for out-place mode: 2417.6

My point is that with use_default_conv=False, the memory used in in-place mode should be about half of that used in out-of-place mode. However, it is not.

(I think you meant to alert @ptrblck ; he is the one who asked for the code.)

Thank you for the code. After playing with it for a bit, I discovered something interesting: if I set use_default_conv=False and change

for i in range(50):

to

for i in range(50000):

then the run takes much more time to complete (as expected), but the memory use reported still remains at 48.0.

This suggests to me that the number of floats in MyConv is too small to cause any real gains when using ReLU in place, even when there are 50000 copies of MyConv in the net. This may be the explanation.

@gphilip

Thanks for the reminder.
To verify your guess, I modified MyConv (see the previous post).
After the modification, memory usage is proportional to the number of layers, and it is larger than with the default Conv.

With use_default_conv=True, the output is:

memory for in-place mode: 819.6
memory for out-place mode: 1619.6

With use_default_conv=False, the output is:

memory for in-place mode: 2401.6
memory for out-place mode: 2417.6

Those results suggest that the in-place mode of ReLU doesn't work for custom autograd functions.
In fact, my real custom function is more complex than the default Conv; I just provided a simple example to illustrate the problem.
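One way to see which tensors autograd keeps alive for backward, without needing a GPU, is to sum the bytes handed to torch.autograd.graph.saved_tensors_hooks (available in PyTorch 1.10 and later). This is a sketch under assumptions, not a replacement for max_memory_allocated: it reproduces the computation of conv and MyConv above with much smaller shapes so it runs on the CPU, and the names Conv, build and saved_bytes (plus the size argument of MyConv) are hypothetical. Tensors that share a data pointer are counted once, so an in-place ReLU's input and output are not double-counted.

```python
import torch
from torch import nn


class Conv(torch.autograd.Function):
    """sin(2 * x), saving the intermediate 2 * x like `conv` above."""

    @staticmethod
    def forward(ctx, x):
        y = 2 * x
        ctx.save_for_backward(y)
        return torch.sin(y)

    @staticmethod
    def backward(ctx, grad_out):
        (y,) = ctx.saved_tensors
        return 2 * grad_out * torch.cos(y)


class MyConv(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(1, 1, size, size))

    def forward(self, x):
        return Conv.apply(x) * self.weight


def build(inplace, layers=4, size=8):
    """Hypothetical helper: the same conv/ReLU stack as net_1/net_2, but small."""
    net = nn.Sequential()
    for i in range(layers):
        net.add_module(f"conv_{i}", MyConv(size))
        net.add_module(f"relu_{i}", nn.ReLU(inplace=inplace))
    return net


def saved_bytes(module, x):
    """Hypothetical helper: total bytes of distinct tensors saved for backward."""
    seen = {}

    def pack(t):
        # Key by data pointer so tensors sharing storage (an in-place ReLU's
        # input and output) are counted only once.
        seen[t.data_ptr()] = t.numel() * t.element_size()
        return t

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        out = module(x)  # holding the output keeps the graph and its saved tensors alive
    return sum(seen.values()), out


x = torch.randn(2, 3, 8, 8)
for inplace in (True, False):
    nbytes, _ = saved_bytes(build(inplace), x)
    print(f"inplace={inplace}: {nbytes} bytes saved for backward")
```

If the two totals come out close, that would suggest the memory is being held by tensors saved by the custom Function and the multiplication rather than by the ReLU.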

Your example should be changed to:

    def forward(ctx, x):
        y = torch.sin(2*x)
        if ctx.needs_input_grad[0]: ctx.save_for_backward(y)
        return y

as your version keeps both the 2x and sin(2x) tensors. I'm not sure about ReLU's role.
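A rough accounting, assuming float32 and the shapes from the code above, shows the scale involved. It is an estimate, not a measurement, and it assumes that each layer keeps three activation-sized tensors alive until backward: 2*x saved by conv, sin(2*x) saved by the multiplication with self.weight (needed for the weight gradient), and the ReLU output.

```python
# Rough accounting for the custom-function model (a sketch, assuming float32
# activations of shape (32, 32, 64, 64) and 50 layers).
BYTES_PER_FLOAT32 = 4
MIB = 1024 ** 2

activation_elems = 32 * 32 * 64 * 64
activation_mib = activation_elems * BYTES_PER_FLOAT32 / MIB
weight_kib = 64 * 64 * BYTES_PER_FLOAT32 / 1024

# Assumption: per layer, three activation-sized tensors stay alive until backward:
#   2*x (saved by the custom Function), sin(2*x) (saved by the multiplication
#   for the weight gradient) and the output of the ReLU.
tensors_per_layer = 3
layers = 50
estimate_mib = layers * tensors_per_layer * activation_mib

print(activation_mib)  # 16.0
print(weight_kib)      # 16.0
print(estimate_mib)    # 2400.0
```

Under that assumption the estimate is close to the 2401.6 MiB reported for in-place mode, while each weight is only 16 KiB. Since conv saves 2*x rather than its input, the ReLU output is not kept by the next layer anyway, which would leave the in-place ReLU little to share.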


Thanks for the modified code. I am getting different results than you when I run a slightly modified version of your code. My results suggest that ReLU being in-place does make a significant difference, even for custom modules.

Here is the modified code that I ran:

import torch
from torch import nn

class MyConv(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.parameter.Parameter(torch.randn(1, 1, 64, 64))
        

    def forward(self, x):
        return (x * self.weight)


def evaluate_memory(module):
    x = torch.rand(32, 32, 64, 64).cuda()
    torch.cuda.reset_peak_memory_stats()
    y = module(x)
    return torch.cuda.max_memory_allocated(torch.device("cuda")) / (1024 ** 2)

net_1 = nn.Sequential()
net_2 = nn.Sequential()

use_default_conv = True

for i in range(50):
    if use_default_conv:
        net_1.add_module("conv_{}".format(i), nn.Conv2d(32, 32, 3, padding=1))
        net_2.add_module("conv_{}".format(i), nn.Conv2d(32, 32, 3, padding=1))
    else:
        net_1.add_module("conv_{}".format(i), MyConv())
        net_2.add_module("conv_{}".format(i), MyConv())
    net_1.add_module("relu_{}".format(i), nn.ReLU(True))
    net_2.add_module("relu_{}".format(i), nn.ReLU(False))

net_1.cuda()
net_2.cuda()

m_1 = evaluate_memory(net_1)
m_2 = evaluate_memory(net_2)
print("memory for in-place mode: {:.1f}".format(m_1))
print("memory for out-place mode: {:.1f}".format(m_2))

When I run this with use_default_conv = True I get:

memory for in-place mode: 819.6
memory for out-place mode: 1619.6

When I run this with use_default_conv = False I get:

memory for in-place mode: 817.6
memory for out-place mode: 1617.6

This shows that making ReLU in-place makes the same degree of difference in both cases.

@gphilip
In your case, the custom function is not used.
Here, "custom function" means a custom autograd function in PyTorch,
that is, a subclass of torch.autograd.Function. See conv in my code.

Thanks for pointing that out, I missed it altogether!

Edited to add:

If you change the line

ctx.save_for_backward(y)

to

ctx.save_for_backward(x)

then there is still a significant difference in the two memory usages reported (with and without in-place, respectively). This suggests that while using the in-place ReLU does make a difference, this difference is swamped by the other memory being used (and not freed) by your forward method.
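Changing only the save_for_backward line is not enough for a correct gradient, because backward would then apply cos to x instead of to 2*x. A minimal sketch of a complete version, under the hypothetical name SinDouble, saves the input x and recomputes 2*x in backward, using d/dx sin(2x) = 2cos(2x). With an in-place ReLU in front of it, the saved x shares storage with the tensor the ReLU already keeps, so no extra activation-sized tensor is stored. torch.autograd.gradcheck compares the hand-written backward against finite differences in double precision.

```python
import torch


class SinDouble(torch.autograd.Function):
    """Hypothetical variant of `conv`: computes sin(2 * x), saves only x."""

    @staticmethod
    def forward(ctx, x):
        # Save the input instead of the intermediate 2 * x.
        ctx.save_for_backward(x)
        return torch.sin(2 * x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # d/dx sin(2x) = 2 * cos(2x); 2 * x is recomputed here instead of stored.
        return 2 * grad_out * torch.cos(2 * x)


x = torch.randn(4, 5, dtype=torch.double, requires_grad=True)
print(torch.autograd.gradcheck(SinDouble.apply, (x,)))  # True
```

This trades a little extra computation in backward for memory. The saved x must not be modified in place afterwards; if it is, autograd raises an error when backward runs.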