The peak memory usage will be higher in the first method, as intermediate tensors are created.
However, the result will not be the same!
Have a look at this code snippet:
import torch

def swish(input):
    return input * torch.sigmoid(input)

def swish_inplace(input):
    # torch.sigmoid_ overwrites input before the multiplication runs
    input = input * torch.sigmoid_(input)
    return input
x = torch.randn(1000, 1000, device='cuda')
torch.cuda.synchronize()
print('Start')
print(torch.cuda.max_memory_allocated() / 1024**2)
print(torch.cuda.memory_allocated() / 1024**2)
output = swish(x)
# output = swish_inplace(x)
torch.cuda.synchronize()
print('After swish')
print(torch.cuda.max_memory_allocated() / 1024**2)
print(torch.cuda.memory_allocated() / 1024**2)
For the first run we’ll get:
Start
4.64013671875
4.64013671875
After swish
12.27001953125
8.455078125
Let’s simplify the numbers a bit and claim our tensor uses ~4MB.
As you can see, the memory usage after the function call is ~8MB, which is expected, since we are still holding on to x and have created the return tensor output.
However, you can also see that the peak memory usage was actually ~12MB, which is explained by the intermediate tensor created by torch.sigmoid(input).
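The ~4MB, ~8MB and ~12MB figures can be sanity-checked with a bit of arithmetic. This is a minimal sketch that assumes the tensor is float32 (the default dtype of torch.randn) and reuses the numbers printed above; the variable names are only for illustration.

```python
# Raw payload of one float32 tensor of shape (1000, 1000), in MiB
tensor_mib = 1000 * 1000 * 4 / 1024**2

# Values reported in the first run above
start = 4.64013671875
peak = 12.27001953125
after = 8.455078125

# Memory added by the returned tensor `output`
output_delta = after - start
# Extra memory that existed only at the peak (the intermediate tensor)
intermediate_delta = peak - after

print(f"one tensor:         {tensor_mib:.3f} MiB")
print(f"after - start:      {output_delta:.3f} MiB")
print(f"peak - after:       {intermediate_delta:.3f} MiB")
```

Both differences come out very close to the size of a single tensor, which fits the picture of one extra tensor for the output and one more for the temporary sigmoid result.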
Now let’s swap the vanilla swish for its inplace variant and run it again:
Start
4.64013671875
4.64013671875
After swish
8.455078125
8.455078125
Awesome! It looks like we are indeed saving some peak memory so we should be using the inplace method, shouldn’t we?
Well, let’s have a look at the results first:
x = torch.randn(1000, 1000, device='cuda')
output = swish(x)
# Clone x to keep the original values for comparison reasons
output_ = swish_inplace(x.clone())
# Compare outputs
print((output == (x * torch.sigmoid(x))).all())
# > tensor(1, device='cuda:0', dtype=torch.uint8)
print((output_ == (torch.sigmoid(x) * torch.sigmoid(x))).all())
# > tensor(1, device='cuda:0', dtype=torch.uint8)
Uh-oh! The inplace method manipulates its input in place before the multiplication, so we are in fact calculating torch.sigmoid(x) * torch.sigmoid(x), which is wrong!
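The aliasing behind this can be seen on a tiny CPU tensor. This is just an illustrative sketch; the names t, original and result are made up for the example.

```python
import torch

t = torch.tensor([-1.0, 0.0, 2.0])
original = t.clone()  # keep the original values for comparison

# The trailing underscore marks an in-place op: t itself is overwritten
result = torch.sigmoid_(t)

# The returned tensor shares its memory with t
same_storage = result.data_ptr() == t.data_ptr()
# t no longer holds the original values, but sigmoid(original)
overwritten = torch.allclose(t, torch.sigmoid(original))

print(same_storage, overwritten)
print(t)
```

Since both operands of the multiplication in swish_inplace refer to this same overwritten tensor, the original values are already gone by the time the product is computed.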
While inplace methods might save some peak memory usage, I’m always careful about using them and always check the result.
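If you still want to avoid the extra temporary, one possible approach is to apply the in-place op to the freshly created sigmoid result instead of the input. The helper name swish_low_peak is hypothetical, and this sketch assumes inference only: modifying the sigmoid output in place is likely to make autograd raise an error during backward, so it runs under torch.no_grad() here.

```python
import torch

def swish_low_peak(input):
    # Hypothetical helper, not part of the original snippet.
    out = torch.sigmoid(input)  # new tensor, input is left untouched
    out.mul_(input)             # in-place on the new tensor only
    return out

with torch.no_grad():
    x = torch.randn(1000, 1000)
    x_before = x.clone()  # copy for checking that x is not modified

    output = swish_low_peak(x)

    # Check the result against the out-of-place reference
    matches = torch.allclose(output, x_before * torch.sigmoid(x_before))
    # Check that the input was not changed
    untouched = torch.equal(x, x_before)

print(matches, untouched)
```

Only the input and the output should be alive at any point, but as with any in-place trick it is worth measuring the peak memory and checking the result yourself.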