Recommended way of updating to new Function interface?

Sorry if I missed some of the other discussions on this. In a lot of my projects I have custom PyTorch Functions defined, and I have been getting a lot of deprecation warnings about updating to the staticmethod interface. Here’s an example of one with a lot of attributes that aren’t necessarily Tensors that I would like to access in the forward/backward pass. I would also like to return some non-Tensor information from inside the forward pass, which I don’t know how to do with the new interface.

Here’s another example that I’m thinking about updating: http://github.com/locuslab/qpth/blob/master/qpth/qp.py

And here are others that I have updated, quite messily, by passing all of the attributes into the forward method and putting them back into the context: a hack to get nearly the same functionality as the old interface. Is returning None in the backward pass for all of the additional flags the right thing to do?

Do you have any recommendations/thoughts before I start making more updates like this? Also, how should I return auxiliary information that will never be backpropped through from a Function call?

cc @albanD

Hi Brandon,

Hm. We should have made it clear earlier that this is the old interface; the “new” Function interface has been around since PyTorch 0.2, so it is bad that you’re surprised…

I think both are the right thing to do.
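A minimal sketch of that pattern, assuming a hypothetical `ScaledReLU` whose non-Tensor settings (`scale`, `verbose`) are passed straight into `forward` and stored on `ctx`. `backward` must return one value per `forward` input, so it returns `None` for the non-Tensor ones:

```python
import torch


class ScaledReLU(torch.autograd.Function):
    """Hypothetical example: non-Tensor settings are extra forward arguments."""

    @staticmethod
    def forward(ctx, input, scale, verbose):
        # Non-Tensor arguments can simply be stored as attributes on ctx.
        ctx.scale = scale
        ctx.verbose = verbose
        ctx.save_for_backward(input)
        return input.clamp(min=0) * scale

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        if ctx.verbose:
            print('backward called with scale', ctx.scale)
        grad_input = grad_output * (input > 0).to(grad_output.dtype) * ctx.scale
        # One return value per forward input: None for scale and verbose.
        return grad_input, None, None


x = torch.randn(5, requires_grad=True)
y = ScaledReLU.apply(x, 2.0, False)  # arguments passed positionally
y.sum().backward()
```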

If you’re in for a hack, you can make good use of closures to eliminate the administration overhead. If things go well, that will also make it easier to script-enable your functions.

Best regards

Thomas

Thanks! Very useful – I like that interface and may start using something like that for new functions in the future. I may also take a shot at adding another level of sugar to handle non-Tensor inputs/outputs with an interface like that.

Also I just played around a bit more with the new Function interface and modified the ReLU example – is it ok to manually instantiate the object so I can set some attributes on it and collect some stats in it that I never need to backprop through? (Arbitrarily using numpy here to make sure it works for non-trivial objects that aren’t tensors)

import torch
import numpy as np

class MyReLU(torch.autograd.Function):
    def __init__(ctx, extra_inputs):
        ctx.extra_inputs = extra_inputs
    
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        print('extra_inputs:', ctx.extra_inputs)
        ctx.extra = {'some_stats': np.random.randn(10)}
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input
    
ctx = MyReLU(extra_inputs=[np.random.randn(10)])
x = torch.randn(10, requires_grad=True)
y = MyReLU.forward(ctx, x)
print(ctx.extra)
print(torch.autograd.grad(y.sum(), x))

Note that ctx is not an instance of Function but rather a “backward context” object, so this won’t work. What you could do is something like

import torch

def MyFunc(extra_inputs):
    class MyFuncFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            ...  # extra_inputs is available here through the closure
        ...
    return MyFuncFn.apply

or similar, but I doubt whether that’s really worth it compared to using functools.partial.
I tend to use mutable types (e.g. dictionaries) for stats.
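For example, a sketch of the `functools.partial` plus mutable-dict approach; the `ReLUWithStats` name and the particular stats collected are made up for illustration. The dict is taken as the first `forward` argument so `partial` can bind it positionally, and `backward` returns `None` for it:

```python
import functools

import torch


class ReLUWithStats(torch.autograd.Function):
    @staticmethod
    def forward(ctx, stats, input):
        # `stats` is a plain dict owned by the caller; filling it in here is a
        # side channel for auxiliary information that is never differentiated.
        stats['num_positive'] = int((input > 0).sum())
        stats['mean'] = float(input.mean())
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        # No gradient for the dict argument.
        return None, grad_input


stats = {}
my_relu = functools.partial(ReLUWithStats.apply, stats)  # bind the dict once

x = torch.randn(10, requires_grad=True)
y = my_relu(x)
y.sum().backward()
print(stats)  # populated during the forward pass
```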

Best regards

Thomas

Ah, I see – in my example above, my torch.autograd.grad call was just relying on autodiff through the forward method and never called into the backward method.
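A small sketch that makes this visible, using a hypothetical `NoisyReLU` that records each call to its `backward`, plus a made-up `FakeCtx` stand-in used only so `forward` can be called by hand. Calling `forward` directly builds an ordinary graph through `clamp`, so autograd differentiates that and the custom `backward` never runs; going through `apply` is what wires the custom `backward` into the graph:

```python
import torch

backward_calls = []  # records every invocation of the custom backward


class NoisyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        backward_calls.append('backward')
        input, = ctx.saved_tensors
        return grad_output * (input > 0).to(grad_output.dtype)


class FakeCtx:
    """Hypothetical stand-in for ctx, only to call forward by hand."""

    def save_for_backward(self, *tensors):
        self.saved = tensors


x = torch.randn(10, requires_grad=True)

# Direct call: the graph goes through clamp, so the custom backward is skipped.
y_direct = NoisyReLU.forward(FakeCtx(), x)
grad_direct, = torch.autograd.grad(y_direct.sum(), x)
calls_after_direct = list(backward_calls)

# Through apply: the custom backward is recorded in the graph and runs.
y_apply = NoisyReLU.apply(x)
grad_apply, = torch.autograd.grad(y_apply.sum(), x)
calls_after_apply = list(backward_calls)
```

Both paths give the same gradient here only because `clamp` happens to be differentiable; for an op autograd cannot differentiate, the direct call would simply fail.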

And thanks! Nesting the Function inside a Python function is a really clean and easy way of updating my older code and not something I considered before – I updated qpth to do this in a few minutes and I think it’ll work well for the mpc one too with some slight modifications :slight_smile:
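For reference, a sketch of that nesting approach applied to a function with non-Tensor settings and auxiliary stats; `make_soft_threshold`, `eps`, and the `num_nonzero` key are hypothetical names chosen for illustration, not taken from qpth:

```python
import torch


def make_soft_threshold(eps, stats=None):
    """Hypothetical factory: eps and stats are captured by the closure
    instead of being stored on a Function instance (old-style interface)."""

    class SoftThresholdFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            ctx.save_for_backward(input)
            out = torch.sign(input) * (input.abs() - eps).clamp(min=0)
            if stats is not None:
                # Auxiliary, non-differentiable info goes into a caller-owned dict.
                stats['num_nonzero'] = int((out != 0).sum())
            return out

        @staticmethod
        def backward(ctx, grad_output):
            input, = ctx.saved_tensors
            # Derivative is 1 where |input| > eps and 0 elsewhere.
            return grad_output * (input.abs() > eps).to(grad_output.dtype)

    return SoftThresholdFn.apply


stats = {}
soft_threshold = make_soft_threshold(eps=0.5, stats=stats)
x = torch.tensor([-1.0, -0.2, 0.3, 2.0], requires_grad=True)
y = soft_threshold(x)
y.sum().backward()
```

One design consequence: each call to the factory defines a fresh Function subclass, so the settings are fixed per returned callable. That is usually fine, but `functools.partial` avoids the extra classes if it matters.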