Combining functional.jvp with a nn.Module

Hi everyone! For my research I need to compute a Jacobian-vector product (JVP), where the Jacobian is that of the outputs of an nn.Module (on a mini-batch) w.r.t. its parameters. In other words, the result of the JVP has the same size as the outputs of the original network.

I installed PyTorch 1.5 because of the new functional.jvp. The input to jvp must be a function with tensor inputs/outputs (in my case, the inputs are all the Parameters of an nn.Module). However, PyTorch does not have a functional version of a module, where the call to the module takes its parameters as arguments (as in JAX).

In this case, I would need to pass the parameters explicitly and then copy them into the module every time the function is called, which looks inefficient.
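To make it concrete, for a single nn.Linear I can hand-write the functional forward with torch.nn.functional (a toy sketch only; the real question is how to do this generically for an arbitrary module):

import torch
import torch.nn.functional as F

net = torch.nn.Linear(3, 2)
inp = torch.rand(5, 3)

# For a single Linear I can hand-write the functional forward,
# but I would like a generic way to do this for any nn.Module.
def f(weight, bias):
    return F.linear(inp, weight, bias)

params = (net.weight.detach(), net.bias.detach())
v = tuple(torch.rand_like(p) for p in params)
out, jvp_out = torch.autograd.functional.jvp(f, params, v=v)
print(jvp_out.shape)  # same shape as the network output: torch.Size([5, 2])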

Is there a simpler way to combine the new interface with an nn.Module? Or is there a simpler way to achieve this JVP that I am overlooking? Thanks!

Hi,

Yes, the nn.Module design makes it quite hard to be functional, since it is built around the parameters being part of the module's state.

But here you can cheat by removing the parameters from the module and setting the new Tensors one by one just before the forward. An example is below; if you want to use it in real code, you should re-organize it so that the original nn.Parameters can be restored afterwards (see the sketch after the example).

import torch

# Helpers to delete / set a (possibly nested) attribute such as "layer.0.weight"
def del_attr(obj, names):
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])

def set_attr(obj, names, val):
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

def make_functional(mod):
    orig_params = tuple(mod.parameters())
    # Remove all the parameters in the model
    names = []
    for name, p in list(mod.named_parameters()):
        del_attr(mod, name.split("."))
        names.append(name)
    return orig_params, names

mod = torch.nn.Linear(1, 10)
orig_params, names = make_functional(mod)
# mod.parameters() is empty now

inp = torch.rand(1, 1)

def functional_mod_fw(*params):
    # re-attach the given tensors as plain attributes (not nn.Parameters) before the forward
    for name, p in zip(names, params):
        set_attr(mod, name.split("."), p)
    return mod(inp)

v = []
for p in orig_params:
    v.append(torch.rand(p.size()))

# jvp returns a tuple: (output of functional_mod_fw, the JVP, which has the same shape)
out = torch.autograd.functional.jvp(functional_mod_fw, orig_params, v=tuple(v))
print(out)
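For completeness, restoring the module afterwards could look roughly like this (my own minimal sketch, reusing set_attr; setting an nn.Parameter attribute on a module registers it as a parameter again):

def restore_functional(mod, names, orig_params):
    # hypothetical helper: put the original nn.Parameters back so the
    # module behaves like a normal nn.Module again
    for name, p in zip(names, orig_params):
        set_attr(mod, name.split("."), p)

restore_functional(mod, names, orig_params)
# mod.parameters() is populated again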

Thanks! It looks a little strange but it should work. 🙂

Supporting the nn.Module approach directly would mean running into the same issue as the checkpointing module we have, which has severe limitations (it forces the user to use backward()) and can lead to unexpected behavior w.r.t. things that require gradients. 😕

I have tested the solution, also including a call to autograd.grad with respect to v, and it seems to work very nicely, thanks!
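For reference, what I tested looks roughly like this (a sketch on top of the make_functional example above; my assumption is that v needs requires_grad=True and the jvp needs create_graph=True for the autograd.grad call to work):

# continuing from the make_functional example above
v = tuple(torch.rand(p.size(), requires_grad=True) for p in orig_params)
out, jvp_out = torch.autograd.functional.jvp(
    functional_mod_fw, orig_params, v=v, create_graph=True)
loss = jvp_out.pow(2).sum()              # some scalar objective of the JVP
grads_wrt_v = torch.autograd.grad(loss, v)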

Out of curiosity: in my case I need to optimize w.r.t. v, meaning that I have to call the JVP repeatedly for separate values of v. In JAX we have linearize, which is basically a curried version of the JVP:

https://jax.readthedocs.io/en/latest/jax.html#jax.linearize
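For example (a minimal JAX sketch of the pattern I would like to replicate):

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 3)

x = jnp.ones(4)
y, f_jvp = jax.linearize(f, x)   # linearize f once around x
print(f_jvp(jnp.ones(4)))        # re-use the linearization for many different v
print(f_jvp(jnp.zeros(4)))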

Do you think it would be feasible to do something similar in the new functional interface?

I’m afraid we don’t have a good solution for this at the moment 😕

When vmap is out, you will be able to use that and vmap over the v argument.

If you’re happy with a hacky approach, you can take the jvp function from torch/autograd/functional.py and modify it as follows:

import torch
from torch.autograd.functional import (
    _as_tuple, _grad_preprocess, _check_requires_grad, _validate_v,
    _autograd_grad, _fill_in_zeros, _grad_postprocess, _tuple_postprocess)

def fw_linearize(func, inputs, create_graph=False, strict=False):
    is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
    inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

    outputs = func(*inputs)
    is_outputs_tuple, outputs = _as_tuple(outputs, "outputs of the user-provided function", "jvp")
    _check_requires_grad(outputs, "outputs", strict=strict)
    # The backward is linear so the value of grad_outputs is not important as
    # it won't appear in the double backward graph. We only need to ensure that
    # it does not contain inf or nan.
    grad_outputs = tuple(torch.zeros_like(out, requires_grad=True) for out in outputs)

    grad_inputs = _autograd_grad(outputs, inputs, grad_outputs, create_graph=True)
    _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)

    def lin_fn(v, retain_graph=True):
        if v is not None:
            _, v = _as_tuple(v, "v", "jvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError("The vector v can only be None if the input to "
                                   "the user-provided function is a single Tensor "
                                   "with a single element.")

        grad_res = _autograd_grad(grad_inputs, grad_outputs, v, create_graph=create_graph, retain_graph=retain_graph)

        jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")

        # Cleanup objects and return them to the user
        jvp = _grad_postprocess(jvp, create_graph)

        return _tuple_postprocess(jvp, is_outputs_tuple)
    return lin_fn


def my_fun(x):
    return x.pow(3).sum()

inp = torch.ones(4)

lin = fw_linearize(my_fun, inp)

v = torch.zeros(4)
print(lin(v))
v = torch.ones(4)
print(lin(v))
v = torch.tensor([1., 0., 0., 0.])
print(lin(v))
v = torch.tensor([0., 1., 0., 0.])
print(lin(v))
v = torch.tensor([0., 0., 1., 0.])
print(lin(v))
v = torch.tensor([0., 0., 0., 1.])
print(lin(v))
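For this example, my_fun(x) = x.pow(3).sum() at x = ones(4) has gradient 3 everywhere, so lin(v) is just 3 * v.sum(): 0 for the zero vector, 12 for the all-ones vector, and 3 for each basis vector.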

Thanks, it works nicely! Looking forward to the next improvements to the functional module.
