Implementation help for Memory Monger (memory optimisation) for PyTorch


I want to implement memory monger for PyTorch. Any suggestions on how to do this?

My question is: the memory monger method suits symbolic graphs like the ones in Theano and TensorFlow, but graphs in PyTorch are dynamic. Can this method be applied to PyTorch?

PS: Links for 1) the memory monger paper and 2) the implementation based on the MXNet framework.
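Here is a rough illustration of why the memory monger approach saves memory. Take a chain of `n` layers split into segments of `k` layers. The forward pass keeps only the segment boundaries, and backward recomputes one segment at a time, so roughly `n/k + k` activations are alive at once. That count is smallest near `k = √n`. The helper below is a hypothetical back-of-the-envelope model (equal-sized activations, parameters and gradients ignored), not code from the paper or the MXNet implementation.

```python
import math


def stored_activations(n_layers: int, segment: int) -> int:
    # Hypothetical cost model: one checkpoint per segment boundary
    # (ceil(n / k) tensors) plus the k activations of the segment that is
    # being recomputed during the backward pass.
    return math.ceil(n_layers / segment) + segment


def best_segment(n_layers: int) -> int:
    # Brute-force search for the segment length k that minimises the model above
    return min(range(1, n_layers + 1), key=lambda k: stored_activations(n_layers, k))


n = 100
k = best_segment(n)
print("without recomputation:", n)                           # 100
print("best segment length:", k)                             # 10
print("peak with recomputation:", stored_activations(n, k))  # 20
```

With 100 layers the model gives 20 live activations instead of 100. The price is roughly one extra forward pass, because each segment is recomputed once during backward.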


You don't know the graphs ahead of time, so you can't do memory monging.
Once we write some more documentation for our torch.jit package, you can implement memonging as a JIT pass.
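For comparison, recent PyTorch releases ship `torch.utils.checkpoint`. It applies the same recompute-in-backward idea eagerly on a dynamic graph. Nothing is planned ahead of time: the user picks the segments by hand. So this is not the automatic graph-level planning a JIT pass would do. A minimal sketch follows. The 8-layer MLP, the tensor sizes and the choice of 4 segments are arbitrary assumptions.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)

# An arbitrary 8-layer MLP used only for illustration
model = nn.Sequential(*[nn.Sequential(nn.Linear(64, 64), nn.ReLU()) for _ in range(8)])
x = torch.randn(16, 64, requires_grad=True)

# Reference: ordinary forward/backward that keeps every activation
ref_out = model(x)
ref_out.sum().backward()
ref_grad = x.grad.clone()
x.grad = None

# Checkpointed: split the model into 4 segments; only segment inputs are kept
# during forward and the checkpointed segments are recomputed during backward
out = checkpoint_sequential(model, 4, x, use_reentrant=False)
out.sum().backward()

print(torch.allclose(out, ref_out), torch.allclose(x.grad, ref_grad))
```

The outputs and input gradients should match the ordinary run. Only the memory/compute trade-off changes.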

Just to clarify: are you saying that I can't implement memonging until the torch.jit documentation is out?


I think one way to implement it would be the following.
Some caveats for this code:

  • It supports only a single backward pass (no double backward).
  • Right now it accepts only one tensor as input to the child module, but this could easily be changed.
import torch
from torch import autograd
from torch.autograd import Variable, Function
from torch.autograd.function import once_differentiable
from torch import nn

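class MongerFunction(Function):
    @staticmethod
    def forward(ctx, inp, mod):
        ctx.mod = mod
        # assume inp is a single tensor, but we could repackage all elements from a list here
        # Keep only the input; running mod under no_grad records no graph,
        # so its intermediate activations are freed after forward
        ctx.save_for_backward(inp)
        with torch.no_grad():
            out = mod(inp)
        return out

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        inp, = ctx.saved_tensors
        inp_var = inp.detach().requires_grad_(True)
        # Recompute the forward graph (once_differentiable runs backward
        # under no_grad, so grad mode has to be re-enabled here)
        with torch.enable_grad():
            out_var = ctx.mod(inp_var)
        # Accumulate gradients in params inside mod (and in inp_var.grad)
        autograd.backward(out_var, grad_out)
        # Return the grad wrt input; mod is not a tensor, so its grad is None
        return inp_var.grad, None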

class Monger(nn.Module):
    def __init__(self, mod):
        super(Monger, self).__init__()
        self.child = mod

    def forward(self, inp):
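        # Force requires grad because parameters inside self.child
        # might require grad but would not be seen from outside.
        # Use a detached copy so the caller's tensor is left untouched
        # (setting requires_grad on a non-leaf tensor would raise an error).
        if not inp.requires_grad:
            inp = inp.detach().requires_grad_()
        return MongerFunction.apply(inp, self.child)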

print("testing with custom function")
def to_mong(inp):
    out = inp * 2
    return out.sum()

inp = Variable(torch.randn(100, 100), requires_grad=True)

monged = Monger(to_mong)
not_monged = to_mong

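# Test monged
out = monged(inp)
out_monged = out.detach().clone()
if inp.grad is not None:
    inp.grad.zero_()
out.backward()
grad_monged = inp.grad.clone()

# Test not monged
out = not_monged(inp)
out_not_monged = out.detach().clone()
if inp.grad is not None:
    inp.grad.zero_()
out.backward()
grad_not_monged = inp.grad.clone()

print("output difference", (out_monged - out_not_monged).abs().sum().item())
print("input grad difference", (grad_monged - grad_not_monged).abs().sum().item())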

print("\ntesting with nn.Module")
to_mong = nn.Sequential(
    nn.Linear(100, 50),
    nn.Linear(50, 1)
)

inp = Variable(torch.randn(1, 100))

monged = Monger(to_mong)
not_monged = to_mong

# Test monged
out = monged(inp)
out_monged = out.detach().clone()
monged.zero_grad()
out.backward()
grad_monged = {}
for name, p in monged.named_parameters():
    grad_monged[name.replace('child.', '')] = p.grad.clone()

# Test not monged
out = not_monged(inp)
out_not_monged = out.detach().clone()
not_monged.zero_grad()
out.backward()
grad_not_monged = {}
for name, p in not_monged.named_parameters():
    grad_not_monged[name] = p.grad.clone()

print("output difference", (out_monged - out_not_monged).abs().sum().item())
print("params grad difference")
for name in grad_monged:
    print("diff for {} is {}".format(name, (grad_monged[name] - grad_not_monged[name]).abs().sum().item()))
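
The second caveat above (one input tensor) could be lifted by passing the tensors through `*args`. Below is a minimal sketch. It assumes every positional input is a tensor and the wrapped module returns a single tensor. `MultiMongerFunction`, `MultiMonger` and the two-input `Gate` module in the usage part are hypothetical names. Like the code above, it supports only a single backward pass.

```python
import torch
from torch import autograd, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable


class MultiMongerFunction(Function):
    @staticmethod
    def forward(ctx, mod, *inps):
        ctx.mod = mod
        # Keep only the inputs; no graph is recorded inside mod here
        ctx.save_for_backward(*inps)
        with torch.no_grad():
            return mod(*inps)

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_out):
        # needs_input_grad[0] belongs to `mod`; the rest match the tensors
        inps = [t.detach().requires_grad_(need)
                for t, need in zip(ctx.saved_tensors, ctx.needs_input_grad[1:])]
        with torch.enable_grad():
            out = ctx.mod(*inps)  # recompute the forward graph
        # Accumulates grads in mod's parameters and in .grad of the inputs
        autograd.backward(out, grad_out)
        return (None,) + tuple(t.grad for t in inps)


class MultiMonger(nn.Module):
    def __init__(self, mod):
        super().__init__()
        self.child = mod

    def forward(self, *inps):
        # Same trick as Monger: make sure autograd records this op even if
        # only the parameters inside self.child require grad
        if not any(t.requires_grad for t in inps):
            inps = (inps[0].detach().requires_grad_(),) + inps[1:]
        return MultiMongerFunction.apply(self.child, *inps)


# Usage with a hypothetical two-input module
class Gate(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 8)

    def forward(self, a, b):
        return (torch.sigmoid(self.lin(a)) * b).sum()


torch.manual_seed(0)
mod = Gate()
a = torch.randn(4, 8, requires_grad=True)
b = torch.randn(4, 8, requires_grad=True)

MultiMonger(mod)(a, b).backward()
monged = [a.grad.clone(), b.grad.clone(), mod.lin.weight.grad.clone()]

a.grad = b.grad = None
mod.zero_grad()
mod(a, b).backward()
plain = [a.grad, b.grad, mod.lin.weight.grad]

ok = all(torch.allclose(m, p) for m, p in zip(monged, plain))
print("gradients match:", ok)
```

Gradients for non-tensor arguments are returned as `None`. The saved inputs are the only tensors kept between forward and backward.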


@smth any update on the torch.jit documentation?