How to make memory-efficient operations?

Hi! I’m playing with custom activations, and I used the following function instead of ReLU. After doing this, I’m running out of GPU memory on CUDA.

def custom_activation(x):
    # x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
    return x * torch.tanh(torch.nn.functional.softplus(x))

How can I make this code more memory efficient? Maybe you can recommend some in-place operations I could use here?
Or maybe there are more efficient ways to implement custom activations? Or maybe there are ways to manage caching on the GPU?

Was the code running fine before, and does it only raise the out-of-memory error when you use these methods?

I am sure these methods cause the error, because it is raised as soon as I replace the ReLU activations with this one.
There is more: if I use this implementation of swish, I also run out of memory:

def swish(input):
    return input * torch.sigmoid(input)

But if I change it as follows everything fits into memory:

def swish(input, inplace=False):
    input = input * torch.sigmoid_(input)
    return input

So I am looking for ways to optimize the tanh and softplus operations and make those more memory efficient.

The peak memory usage will be higher in the first method, as intermediate tensors are created.
However, the result will not be the same!
Have a look at this code snippet:

import torch

def swish(input):
    return input * torch.sigmoid(input)

def swish_inplace(input):
    input = input * torch.sigmoid_(input)
    return input


x = torch.randn(1000, 1000, device='cuda')

torch.cuda.synchronize()
print('Start')
print(torch.cuda.max_memory_allocated() / 1024**2)
print(torch.cuda.memory_allocated() / 1024**2)

output = swish(x)
# output = swish_inplace(x)

torch.cuda.synchronize()
print('After swish')
print(torch.cuda.max_memory_allocated() / 1024**2)
print(torch.cuda.memory_allocated() / 1024**2)

For the first run we’ll get:

Start
4.64013671875
4.64013671875
After swish
12.27001953125
8.455078125

Let’s simplify the numbers a bit and claim our tensor uses ~4MB.
As you can see, the memory usage after the function call is ~8MB, which is expected, since we are still holding on to x and have created the return tensor output.
However, you can also see that the peak memory usage was actually ~12MB, which is explained by the intermediate tensor created by torch.sigmoid(input).
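
As a quick sanity check, each newly allocated 1000x1000 float32 tensor should take about 3.81MB, which matches both deltas in the printout:

# Expected size of one 1000x1000 float32 tensor in MB
print(1000 * 1000 * 4 / 1024**2)  # ~3.81
# 8.455 - 4.640 ≈ 3.81  -> the returned output tensor
# 12.270 - 8.455 ≈ 3.81 -> the temporary torch.sigmoid(input) tensor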

Now let’s swap the vanilla swish for its inplace method and run it again:

Start
4.64013671875
4.64013671875
After swish
8.455078125
8.455078125

Awesome! It looks like we are indeed saving some peak memory, so we should be using the inplace method, shouldn’t we?
Well, let’s have a look at the results first:

x = torch.randn(1000, 1000, device='cuda')
output = swish(x)
# Clone x to keep the original values for comparison reasons
output_ = swish_inplace(x.clone())

# Compare outputs
print((output == (x * torch.sigmoid(x))).all())
> tensor(1, device='cuda:0', dtype=torch.uint8)
print((output_ == (torch.sigmoid(x) * torch.sigmoid(x))).all())
> tensor(1, device='cuda:0', dtype=torch.uint8)

Uh oh! The inplace method manipulates its input in place, so we are in fact calculating torch.sigmoid(x) * torch.sigmoid(x), which is wrong!
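
If the goal is just to avoid that extra intermediate, one option (a sketch with a made-up name, not the only way) is to run the in-place multiplication on the sigmoid output instead of on the input:

def swish_alloc_once(input):
    # Allocate sigmoid(input) once, then scale it by input in place.
    # Peak memory is input + result (~8MB in the example above),
    # and input itself is left untouched, so the values are correct.
    result = torch.sigmoid(input)
    result.mul_(input)
    return result

Keep in mind that if input requires gradients, the in-place mul_ on sigmoid’s output will usually error out at backward time, since sigmoid saves its output for the backward pass, so this trick mainly helps for inference or for tensors that don’t require grad.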

While inplace methods might save some peak memory usage, I’m always careful about using them and always check the result.
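
If the memory pressure comes from training, where autograd keeps intermediates such as the sigmoid output alive until backward, another technique is to recompute the cheap part in the backward pass instead of storing it. Here is a rough sketch for swish using a custom torch.autograd.Function and the derivative sigmoid(x) * (1 + x * (1 - sigmoid(x))); the class and function names are just placeholders:

import torch

class SwishRecompute(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # only the input is kept alive for backward
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        s = torch.sigmoid(x)  # recomputed here instead of being stored
        return grad_output * (s * (1 + x * (1 - s)))

def memory_saving_swish(x):
    return SwishRecompute.apply(x)

The forward peak is unchanged, but between forward and backward only x stays alive instead of x plus the sigmoid output. The same idea carries over to x * tanh(softplus(x)), and torch.utils.checkpoint offers a similar compute-for-memory trade-off at a coarser granularity.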


Thanks a lot! Inplace methods are really tricky.
Could you suggest some techniques to save memory for the tanh and softplus operations?