Custom autograd.Function: backward pass not called

Good afternoon!

This problem already came up in my other thread, but it isn’t really related to that topic, so I moved it to a new thread.

Whenever I use a custom autograd.Function (i.e., a function with an explicitly defined backward pass) and combine it with any torch.nn module, the backward pass is never properly executed.
Here’s an MWE containing a simple identity transform as an example:


import torch
import torch.autograd
import torch.nn



# custom autograd function
class Identity(torch.autograd.Function):
    def __init__(self):
        super(Identity, self).__init__()

    @staticmethod
    def forward(ctx, input):
        return input.clone()        # REASON FOR ERROR: forgot to .clone() here

    @staticmethod
    def backward(ctx, grad_output):
        print("backward")
        grad_input = grad_output.clone()
        return grad_input

identity = Identity.apply


input = torch.autograd.Variable(torch.randn(1, 2, 20, 20).float(), requires_grad=True)
target = torch.autograd.Variable(torch.randn(1, 2, 20, 20).float(), requires_grad=False)


# forward pass:
# - if any torch.nn function is called before the custom Function,
#   its backward pass is never executed.

intermediate = input                           # calling backward upon this works fine
# intermediate = torch.nn.ReLU()(input)        # uncomment this line and it won't work anymore

pred = identity(intermediate)

loss = torch.nn.MSELoss()(pred, target)
loss.backward()                                # this should always print "backward" to the command line, but it doesn't...

I am running the latest version of PyTorch (Anaconda says 0.2.0, py27hc03bea1_4cu80 [cuda80] soumith), with CUDA 8.0 and Python 2.7.

This seemed to work with some earlier PyTorch versions, but now it doesn’t anymore.

Does anybody know what the matter is, or what I’m doing wrong?

Thank you very much for any answers!

EDIT: Silly me. I’ve been trying to solve this for quite some time, and literally 5 minutes after opening this thread I discovered my mistake: I forgot to clone the input during the forward pass (code changed above); now it works.
Sorry for the spamming; this thread can be closed.

If you don’t want to add the clone, you should mark that the input and output share storage with the mark_shared_storage function, like here.
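For reference, here is a rough sketch of what that might have looked like with the legacy ctx.mark_shared_storage API (the signature is assumed to take (input, output) pairs; the method was deprecated and later removed, so on current PyTorch you would clone or use mark_dirty instead):

# Legacy-API sketch only: assumes ctx.mark_shared_storage((input, output)) pairs;
# this method was deprecated and later removed from PyTorch.
class IdentityShared(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        output = input.view_as(input)              # output shares storage with input
        ctx.mark_shared_storage((input, output))   # tell autograd about the aliasing
        return output

    @staticmethod
    def backward(ctx, grad_output):
        print("backward")
        return grad_output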


Hey, could you explain to me why I need to return input.clone()?

@albanD: thanks a lot for the additional input! I’m quite fine with cloning, since my variables aren’t immensely large.

@Akash_Goel: I’m obviously not an expert, but I assume the reason is the directed graph that is constructed behind the scenes while running the forward pass on a variable. Every variable is linked to its predecessors and successors, and during the backward pass this directed graph is traversed in reverse order.
If one doesn’t clone the variable, its executor function is probably just overwritten, so one of the building blocks gets skipped. By cloning, however, a new variable is appended and assigned to the custom block, which preserves the execution order.
Correct me if I’m wrong.
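As a quick sanity check, here is a minimal sketch (reusing identity and input from the MWE above; the exact grad_fn node names vary between PyTorch versions) that inspects the graph to confirm the custom node really was inserted:

# With the .clone() in forward, the output's grad_fn should be the custom Identity
# node, and the ReLU backward node should show up among its predecessors.
pred = identity(torch.nn.ReLU()(input))
print(pred.grad_fn)                  # something like <IdentityBackward ...>
print(pred.grad_fn.next_functions)   # should contain the ReLU backward node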


@Technics: I think you are correct: since there is no operation performed on "input", no new operation has been added to the graph, and since the operation is not in the graph, its backward function won’t be called during backpropagation :slight_smile:

My question in my post is also related to a custom autograd.Function and the backward pass. Inspired by your code, I modified mine, and it works now in the newer version of PyTorch. However, I do not clone the input during the forward pass, and it still works. Interesting, but also confusing!

I had a similar problem. For me, mark_dirty worked perfectly: it makes backward get called without needing to clone the input tensor or store pairs of tensors with mark_shared_storage.

class ManipulateGradient(Function):
    def forward(self, input):
        self.mark_dirty(input)
        return input

    def backward(self, grad_out):
        # manipulate gradient here
        return grad_out + 0.42

Does it really work for you?

import torch
from torch.autograd import Function
from torch import tensor

class ManipulateGradient(Function):

    def forward(self, input):
        self.mark_dirty(input)
        print("forward Function")
        return input

    def backward(grad_out):
        # manipulate gradient here
        print("backward Function")
        return grad_out + 0.42

if __name__ == "__main__":
    layer = ManipulateGradient()
    input = tensor([1.0, 2.0, 3.0], requires_grad=True)
    result = layer.apply(input)
    print("result: ", result)
    dout = tensor([0.1, -0.1, 0.2])
    result.backward(dout)
    print("input.grad: ", input.grad)

Gives me:

forward Function
Traceback (most recent call last):
    result = layer.apply(input)
RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

Okay, let’s remove

self.mark_dirty(input)

Then we get:

forward Function
result:  tensor([1., 2., 3.], grad_fn=<FunctionBackward>)
Traceback (most recent call last):
  File "manipulate_gradient_non_static.py", line 22, in <module>
    result.backward(dout)
  File "python3.6/site-packages/torch/tensor.py", line 102, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "python3.6/site-packages/torch/autograd/__init__.py", line 90, in backward
    allow_unreachable=True)  # allow_unreachable flag
  File "python3.6/site-packages/torch/autograd/function.py", line 76, in apply
    return self._forward_cls.backward(self, *args)
  File "python3.6/site-packages/torch/autograd/function.py", line 180, in backward
    raise NotImplementedError
NotImplementedError

Let’s write a proper Function with @staticmethods:

class ManipulateGradientStatic(Function):

    @staticmethod
    def forward(ctx, input):
        # ctx.mark_dirty(input)
        print("forward static")
        return input

    @staticmethod
    def backward(ctx, grad_out):
        # manipulate gradient here
        print("backward static")
        return grad_out + 0.42

class ManipulateGradientModule(torch.nn.modules.Module):
    def __init__(self):
        super(ManipulateGradientModule, self).__init__()

    def forward(self, input):
        print("forward Module")
        return ManipulateGradientStatic.apply(input)

if __name__ == "__main__":
    layer = ManipulateGradientModule()
    input = tensor([1.0, 2.0, 3.0], requires_grad=True)
    result = layer(input)
    print("result: ", result)
    dout = tensor([0.1, -0.1, 0.2])
    result.backward(dout)
    print("input.grad: ", input.grad)

Gives:

forward Module
forward static
result:  tensor([1., 2., 3.], grad_fn=<ManipulateGradientStaticBackward>)
backward static
input.grad:  tensor([0.5200, 0.3200, 0.6200])

Is there any way to debug why my static backward method (much more complicated than the example above) is not called during backprop?

You can look at the chain of grad_fn and next_functions:

print(result.grad_fn)
print(result.grad_fn.next_functions)
print(result.grad_fn.next_functions[0][0].next_functions)
...

Usually the issue has something to do with either a detach() call (which stops backpropagation) or returning the input unmodified in forward. Try returning input.clone() to see if it fixes the issue.
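If the chain is deep, a small helper like the following sketch (it only relies on the grad_fn and next_functions attributes) can print the whole graph, so you can check whether your custom backward node shows up at all:

# Recursively print the autograd graph below a tensor's grad_fn; your custom
# node (e.g. ManipulateGradientStaticBackward) should appear somewhere in it.
def walk_graph(fn, depth=0):
    if fn is None:
        return
    print("  " * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        walk_graph(next_fn, depth + 1)

walk_graph(result.grad_fn)   # 'result' is the output tensor of your forward pass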


Hi, apologies for reviving an old thread, but even with the solutions indicated here I am still a bit confused. I have implemented a gradient reversal layer. I started by implementing a function:

# functional.py
from torch.autograd import Function


class RevGrad(Function):
    @staticmethod
    def forward(ctx, input_):
        ctx.save_for_backward(input_)
        output = input_.clone()
        return output

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = None
        if ctx.needs_input_grad[0]:
            grad_input = -grad_output.clone()
        return grad_input


revgrad = RevGrad.apply

I then use this function in a layer:

# __init__.py
from .functional import revgrad
from torch.nn import Module


class RevGrad(Module):
    def __init__(self, *args, **kwargs):
        """
        A gradient reversal layer.

        This layer has no parameters, and simply reverses the gradient
        in the backward pass.
        """

        super().__init__(*args, **kwargs)

    def forward(self, input_):
        return revgrad(input_)

This all works as expected, and the following test (run with pytest) succeeds:

# test_layer.py

import copy
import torch
from pytorch_revgrad import RevGrad


def test_gradients_inverted():
    network = torch.nn.Sequential(torch.nn.Linear(5, 3), torch.nn.Linear(3, 1))
    revnetwork = torch.nn.Sequential(copy.deepcopy(network), RevGrad())

    inp = torch.randn(8, 5)
    outp = torch.randn(8)

    criterion = torch.nn.MSELoss()
    criterion(network(inp), outp).backward()
    criterion(revnetwork(inp), outp).backward()
    assert all(
        (p1.grad == -p2.grad).all()
        for p1, p2 in zip(network.parameters(), revnetwork.parameters())
    )

So far so good. However, when I run this with the pytest-cov plugin, to check the coverage of my source code, it makes it look as if the backward call in my function is never called:

[screenshot: coverage report showing the backward method of RevGrad as never executed]

I’m very confused by this - the gradients do get reversed, as is clear from the test succeeding. But somehow this happens without the backward call being made? Am I missing something fundamental here?


Hi,

I’m not sure how the coverage is checked, but if it checks whether a Python function calls the backward, that will never happen: backward is called directly by some C++ code, so the coverage tool might miss it.


Thanks, that’s extremely helpful to know!

I think that is the issue; as far as I know, coverage.py uses sys.settrace to track execution, so that explains things.
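For anyone who wants to convince themselves, here is a minimal sketch (reusing the revgrad function from above) showing that the backward really does run, even though the coverage report misses it:

# register_hook fires during backward with the gradient w.r.t. x, so if the custom
# backward runs, the hook sees the reversed (negative) gradient of sum().
import torch

x = torch.randn(4, requires_grad=True)
y = revgrad(x)

seen = []
x.register_hook(lambda g: seen.append(g.clone()))

y.sum().backward()
print(seen[0])   # tensor([-1., -1., -1., -1.]) -- the reversed gradient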

Sorry to ask a silly question, but why is it necessary to have this layer in the first place? Could you not negate the loss function directly? Then the gradients would be flipped automatically, with no extra lines of code?

Hi, I am also facing a similar issue. Have you found any solution? Thanks.

Hi, the answer is up there: it is being called, it’s just hidden from the Python code because it’s called from C++.

Sorry, only just saw this.

It’s for when you want to negate the gradients selectively for some parts of the network; multiplying the loss by -1 works, but it doesn’t give you any choice over which parts of the network get updated with the inverted gradients.
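To make that concrete, here is a small sketch using the RevGrad module defined earlier in this thread (the features/task_head/domain_head names are just illustrative) of the selective setup that multiplying the loss by -1 cannot express:

# Only the domain branch sends reversed gradients into the shared feature
# extractor; the task branch still trains it normally.
import torch

features = torch.nn.Linear(5, 3)
task_head = torch.nn.Linear(3, 1)
domain_head = torch.nn.Sequential(RevGrad(), torch.nn.Linear(3, 1))

x = torch.randn(8, 5)
f = features(x)
loss = task_head(f).pow(2).mean() + domain_head(f).pow(2).mean()
loss.backward()   # features.weight.grad mixes normal and reversed contributions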
