Recalculating activations in the backward pass for memory efficiency

I have a memory-hungry function (i.e., it requires O(n^2) memory for O(n) input, but its output is only O(n)).
In order to increase the mini-batch size, I would like to free the intermediate activations and recalculate them
in the backward pass. I have implemented an autograd Function, which also uses autograd inside its backward
method to calculate the gradients. However, when the .backward() method gets called inside the implemented
autograd Function, the program gets stuck. Here is the code:

import torch
from torch.autograd import Variable
import torch.nn as nn

class MemEfficientFun(torch.autograd.Function):
    
    def __init__(self, fun):
        super(MemEfficientFun, self).__init__()
        self.fun = fun

    def forward(self, *inputs):
        # Save the inputs so the intermediate activations can be recomputed in backward.
        self.save_for_backward(*inputs)
        return self.fun(*inputs)

    def backward(self, grad_output):
        # Re-wrap the saved tensors so autograd can track the recomputation.
        inputs = [Variable(x, requires_grad=True) for x in self.saved_tensors]
        output = self.fun(*inputs)
        # This is where the program gets stuck:
        output.backward(grad_output)
        grad_inputs = tuple(x.grad for x in inputs)
        if len(grad_inputs) == 1:
            grad_inputs, = grad_inputs
        return grad_inputs

def memory_hungry_fun(x):
    return torch.mm(x, x.t()).sum(1)
    
class MyNet(nn.Module):
    
    def __init__(self):
        super(MyNet, self).__init__()
        self.FC = nn.Linear(10, 10)
        self.Fun = MemEfficientFun(memory_hungry_fun)
        
    def forward(self, x):
        x = self.FC(x)
        x = self.Fun(x)
        return x.sum()

def test_net():
    net = MyNet()
    x = Variable(torch.randn(32, 10))
    y = net(x)
    y.backward()
    print("Success!")
    
    
if __name__ == "__main__":
    test_net()

Is there any way I can make this work in PyTorch?

Thanks.

Hi,

This is a known issue that is tracked here: #1776.
This would be a complex change, so I don't think it is high priority right now.
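
For anyone finding this thread later: this recompute-in-backward pattern is what torch.utils.checkpoint.checkpoint provides in later PyTorch releases; it discards the intermediate activations during the forward pass and recomputes them during backward. A minimal sketch of the example above rewritten with it (assuming a PyTorch version that ships torch.utils.checkpoint):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def memory_hungry_fun(x):
    # Same O(n^2)-memory function as above.
    return torch.mm(x, x.t()).sum(1)

class MyNet(nn.Module):

    def __init__(self):
        super(MyNet, self).__init__()
        self.FC = nn.Linear(10, 10)

    def forward(self, x):
        x = self.FC(x)
        # checkpoint() runs memory_hungry_fun without storing its intermediate
        # activations, and recomputes them when backward reaches this point.
        x = checkpoint(memory_hungry_fun, x)
        return x.sum()

if __name__ == "__main__":
    net = MyNet()
    x = torch.randn(32, 10, requires_grad=True)
    y = net(x)
    y.backward()
    print("Success!")

With this, the O(n^2) intermediate inside memory_hungry_fun is only held transiently: once while the forward pass runs, and once more while the backward pass recomputes it.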
