I want to implement memory monger for pytorch. Any suggestions on how to do this ?
My question is, this memory monger method is suitable for symbolic graphs like the ones in Theano and Tensorflow, but the graphs in pytorch are dynamic, so can this method be applied for pytorch ?
PS: Links for 1) the memory monger paper and 2) an implementation based on the MXNet framework.
You don't know the graphs ahead of time, so you can't do memory monging.
Once we write some more documentation on our torch.jit package, you can implement memonging as a JIT pass.
I think a way to implement it would be:
Some caveats for the following code:
Supports only a single backward pass
Right now can only accept one tensor as input for the child module, but this could be easily changed
import torch
from torch import autograd
from torch.autograd import Variable, Function
from torch.autograd.function import once_differentiable
from torch import nn
class MongerFunction(Function):
    """Gradient-checkpointing ("memory monging") autograd function.

    Runs ``mod`` without building an autograd graph in ``forward`` (saving
    activation memory) and re-runs it in ``backward`` to recompute the graph
    and obtain gradients — trading compute for memory.

    Caveats (as in the original post): supports a single backward pass and a
    single tensor input for the child module.
    """

    @staticmethod
    def forward(ctx, inp, mod):
        # ``mod`` is a callable (function or nn.Module); stash it on ctx so
        # backward can re-run the forward pass.
        ctx.mod = mod
        ctx.save_for_backward(inp)
        # Recompute-friendly forward: no graph is recorded here.
        # (Replaces the removed `Variable(inp, volatile=True)` idiom.)
        with torch.no_grad():
            out = mod(inp)
        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        inp, = ctx.saved_tensors
        # Fresh leaf sharing storage with the saved input.
        inp = inp.detach().requires_grad_(True)
        # @once_differentiable runs backward under no_grad, so we must
        # explicitly re-enable grad to rebuild the forward graph.
        with torch.enable_grad():
            out = ctx.mod(inp)
        # Accumulates gradients into any parameters inside ctx.mod as a side
        # effect, and into inp.grad.
        out.backward(grad_out)
        # Gradient w.r.t. the original input; None for the ``mod`` argument.
        return inp.grad, None
class Monger(nn.Module):
    """Module wrapper that checkpoints ``mod`` via :class:`MongerFunction`.

    Forward activations of the wrapped child are not stored; they are
    recomputed during backward.
    """

    def __init__(self, mod):
        super(Monger, self).__init__()
        # Registered as a submodule so its parameters are visible through
        # ``named_parameters()`` / ``zero_grad()`` on the wrapper.
        self.child = mod

    def forward(self, inp):
        # Force requires_grad because parameters inside self.child might
        # require grad but that would not be seen from outside the custom
        # Function. Guarded: setting it on a non-leaf tensor is an error,
        # and a non-leaf that reaches us already has requires_grad=True.
        if not inp.requires_grad:
            inp.requires_grad_(True)
        # New-style Functions are invoked via the static ``apply``
        # (instantiating the Function and calling apply is the legacy API).
        return MongerFunction.apply(inp, self.child)
print("testing with custom function")
def to_mong(inp):
out = inp * 2
return out.sum()
inp = Variable(torch.randn(100, 100), requires_grad=True)
monged = Monger(to_mong)
not_monged = to_mong
# Test monged
out = monged(inp)
out_monged = out.data.clone()
if inp.grad is not None:
inp.grad.data.zero_()
out.backward()
grad_monged = inp.data.clone()
# Test not monged
out = to_mong(inp)
out_not_monged = out.data.clone()
if inp.grad is not None:
inp.grad.data.zero_()
out.backward()
grad_not_monged = inp.data.clone()
print("output difference")
print((out_monged-out_not_monged).abs().sum())
print("input grad difference")
print((grad_monged-grad_not_monged).abs().sum())
print("\ntesting with nn.Module")
to_mong = nn.Sequential(
nn.Linear(100, 50),
nn.ReLU(),
nn.Linear(50, 1)
)
inp = Variable(torch.randn(1, 100))
monged = Monger(to_mong)
not_monged = to_mong
# Test monged
out = monged(inp)
out_monged = out.data.clone()
monged.zero_grad()
out.backward()
grad_monged = {}
for name, p in monged.named_parameters():
grad_monged[name.replace('child.', '')] = p.grad.data.clone()
# Test not monged
out = to_mong(inp)
out_not_monged = out.data.clone()
to_mong.zero_grad()
out.backward()
grad_not_monged = {}
for name, p in to_mong.named_parameters():
grad_not_monged[name] = p.grad.data.clone()
print("output difference")
print((out_monged-out_not_monged).abs().sum())
print("params grad difference")
for name in grad_monged:
print("diff for {} is {}".format(name, (grad_monged[name]-grad_not_monged[name]).abs().sum()))