Combining torch.autograd.functional.jvp with an nn.Module

Hi,

Yes, the nn.Module design makes it quite hard to use functionally, because it relies on the parameters being part of the module's state.

Here, though, you can cheat: remove the parameters from the module, then set the new Tensors on it one by one before each forward pass. An example is below. If you want to use it in real code, I think you should reorganize it so that the original nn.Parameters can be restored afterwards.

import torch

def del_attr(obj, names):
    """Delete the attribute found by following the path ``names`` from ``obj``.

    ``names`` is a sequence of attribute names, e.g. ``"0.weight".split(".")``.
    Every name except the last is resolved with ``getattr`` to reach the
    owning object, and the last name is then removed from it with ``delattr``.
    An empty ``names`` raises ``IndexError``.
    """
    owner = obj
    for hop in names[:-1]:
        owner = getattr(owner, hop)
    delattr(owner, names[-1])
def set_attr(obj, names, val):
    """Assign ``val`` to the attribute found by following ``names`` from ``obj``.

    ``names`` is a sequence of attribute names, e.g. ``"0.weight".split(".")``.
    Every name except the last is resolved with ``getattr`` to reach the
    owning object, and the last name is then bound to ``val`` with ``setattr``.
    An empty ``names`` raises ``IndexError``.
    """
    target = obj
    for step in names[:-1]:
        target = getattr(target, step)
    setattr(target, names[-1], val)

def make_functional(mod):
    """Strip every parameter from ``mod`` so its forward can run on plain Tensors.

    Each parameter is deleted from the (sub)module that registers it, so
    afterwards ``mod.parameters()`` yields nothing. Tensors must be assigned
    back under the same names before ``mod`` is called again.

    Args:
        mod: the ``nn.Module`` to strip; it is modified in place.

    Returns:
        A tuple ``(orig_params, names)``. ``orig_params`` holds the removed
        parameters in ``named_parameters()`` order. ``names`` is a list of
        their dotted attribute paths (e.g. ``"0.weight"``), aligned
        index-for-index with ``orig_params``.
    """
    # Snapshot the (name, param) pairs once. Both return values come from the
    # same traversal, so they are guaranteed to line up. The list() copy also
    # keeps iteration safe while the loop below deletes attributes from mod.
    # NOTE(review): named_parameters() de-duplicates shared parameters by
    # default, so a tensor registered under several names is only removed
    # under the first of them.
    named = list(mod.named_parameters())
    orig_params = tuple(param for _, param in named)
    names = [name for name, _ in named]
    for name in names:
        del_attr(mod, name.split("."))
    return orig_params, names

# Example model: strip its parameters. `orig_params` keeps the removed tensors
# and `names` their dotted attribute paths, in matching order.
mod = torch.nn.Linear(in_features=1, out_features=10)
orig_params, names = make_functional(mod)
# From this point on, mod.parameters() yields nothing.

def functional_mod_fw(*params):
    """Run ``mod`` on the global ``inp``, using ``params`` as its parameters.

    ``params`` must be given in the same order as the module-level ``names``
    (as returned by ``make_functional``). Each tensor is assigned onto ``mod``
    under the matching dotted name before the forward pass runs. This is the
    function handed to ``torch.autograd.functional.jvp`` below.

    Raises:
        ValueError: if the number of tensors differs from ``len(names)``.
    """
    # zip() stops silently at the shorter sequence. Some attributes would then
    # stay unset, or keep tensors assigned by an earlier call, so reject a
    # count mismatch explicitly.
    if len(params) != len(names):
        raise ValueError(
            f"expected {len(names)} parameter tensors, got {len(params)}"
        )
    for name, p in zip(names, params):
        set_attr(mod, name.split("."), p)
    return mod(inp)

# Input for the forward pass; functional_mod_fw reads it as a global.
inp = torch.rand(1, 1)

# One random tangent per original parameter. rand_like matches each
# parameter's shape and also its dtype and device, whereas
# torch.rand(p.size()) always uses the default dtype on the CPU. This keeps
# the tangents compatible with the parameters if the model is moved or cast.
v = [torch.rand_like(p) for p in orig_params]

# jvp returns a pair: (output of functional_mod_fw at orig_params,
# Jacobian-vector product of that function with the tangents v).
out = torch.autograd.functional.jvp(functional_mod_fw, orig_params, v=tuple(v))
print(out)
2 Likes