Simple `.backward()` override

Hi,

I would like to call a function right after the backward pass on a module. It seems like the backward function is never called, which does not really make sense because the module definitely learns and updates its weights.

My module is something like:

import torch.nn as nn

class BaseModule(nn.Module):
    [...]

class MyModule(BaseModule):
    def fct(self):
        print("fct")

    # Intended to run fct() right after the backward pass
    def backward(self, *args, **kwargs):
        o = super(MyModule, self).backward(*args, **kwargs)
        self.fct()
        return o

Whatever I put in fct, it has no effect, because it is never executed.
I don’t understand why this function isn’t called.

Any ideas?

Hi,

This function is not called because an nn.Module is differentiated entirely by autograd: the backward pass is built from the operations recorded during forward, so a backward method defined on the module is never used.
If you want to define both the forward and the backward yourself, you need to use an autograd.Function, but then you have to implement the full backward yourself.
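As a rough sketch of the autograd.Function route (the class `ScaleWithCallback` and its callback argument are hypothetical, chosen only for illustration), here is a function computing y = 2 * x whose hand-written backward also calls a user function:

```python
import torch


class ScaleWithCallback(torch.autograd.Function):
    """Hypothetical example: y = 2 * x, with a user callback run during backward."""

    @staticmethod
    def forward(ctx, x, callback):
        # Keep the (non-tensor) callback on ctx so backward can reach it.
        ctx.callback = callback
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        # The whole gradient must be written by hand: d(2x)/dx = 2.
        grad_x = grad_output * 2
        ctx.callback()
        # One gradient per forward input; the callback gets None.
        return grad_x, None


calls = []
x = torch.ones(3, requires_grad=True)
y = ScaleWithCallback.apply(x, lambda: calls.append("fct"))
y.sum().backward()
print(calls)   # ['fct']
print(x.grad)  # tensor([2., 2., 2.])
```

The cost of this approach is that the gradient formula has to be derived and maintained manually, which is why a hook is usually the simpler choice when you only want to run code after the backward.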

If you want to call a function after the backward of a module has been computed, you should use a backward hook; see the documentation on module hooks for how to use them.
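A minimal sketch of the hook approach, assuming PyTorch 1.8 or later (where `nn.Module.register_full_backward_hook` is available). `MyModule` and `fct` mirror the question; the inner `nn.Linear`, the `calls` counter and the `_after_backward` method are hypothetical additions so the example runs on its own:

```python
import torch
import torch.nn as nn


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)
        self.calls = 0  # hypothetical counter, only to show the hook ran
        # Call _after_backward each time gradients w.r.t. this module's inputs are computed.
        self.register_full_backward_hook(self._after_backward)

    def forward(self, x):
        return self.linear(x)

    def fct(self):
        self.calls += 1
        print("fct")

    def _after_backward(self, module, grad_input, grad_output):
        self.fct()
        # Returning None leaves grad_input unchanged.
        return None


model = MyModule()
x = torch.randn(4, 3, requires_grad=True)
model(x).sum().backward()  # prints "fct"
print(model.calls)         # 1
```

The hook receives the gradients flowing out of (`grad_input`) and into (`grad_output`) the module, so it can also inspect or replace them instead of just triggering a side effect.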


Fast and accurate. Thanks a lot.