I have a memory-hungry function (i.e. requires O(n^2) memory for O(n) input , but its output is only O(n)).
In order to increase the mini-batch size, I would like to free the intermediate activations, and recalculate them
in the backward pass. I have implemented an autograd Function, which also uses autograd inside its backward
method to calculate the gradients. However, when the .backward() methods gets called inside the implemented
autograd function, the program gets stuck. Here is the code:
import torch
from torch.autograd import Variable
import torch.nn as nn
class MemEfficientFun(torch.autograd.Function):
def __init__(self, fun):
super(MemEfficientFun, self).__init__()
self.fun = fun
def forward(self, *inputs):
self.save_for_backward(*inputs)
return self.fun(*inputs)
def backward(self, grad_output):
inputs = [Variable(x,requires_grad=True) for x in self.saved_tensors]
output = self.fun(*inputs)
# This is where the program gets stuck:
output.backward(grad_output)
grad_inputs = tuple([x.grad for x in inputs])
if len(grad_inputs) == 1:
grad_inputs, = grad_inputs
return grad_inputs
def memory_hungry_fun(x):
return torch.mm(x,x.t()).sum(1)
class MyNet(nn.Module):
def __init__(self):
super(MyNet, self).__init__()
self.FC = nn.Linear(10,10)
self.Fun = MemEfficientFun(memory_hungry_fun)
def forward(self, x):
x = self.FC(x)
x = self.Fun(x)
return x.sum()
def test_net():
net = MyNet()
x = Variable(torch.randn(32,10))
y = net(x)
y.backward()
print("Success!")
if __name__ == "__main__":
test_net()
Is there any way I can make this work in PyTorch?
Thanks.