Backward of non-sequential module

Hi everyone,

I would like to implement a module in which I can apply some operations to intermediate layers, such that:

(1) requires_grad = False for these operations a, b

(2) requires_grad = True for out

(3) force PyTorch to build the autograd graph so that the order is: input -> out (conv1) -> out (relu1). That way, PyTorch will not need to compute gradients for function_a or function_b in the backward pass. The reason is that function_a and function_b are not provided by PyTorch, and it is very hard to compute their derivatives.

I am not sure whether this is feasible or not. Do you have any ideas on how to proceed?

Thank you so much for reading.

A simple network can be seen as follows:

class CustomModule(nn.Module):
    def __init__(self, module):
        super(CustomModule, self).__init__()
        self.module = module

    def forward(self, input):
        a = function_a(input)       # no gradient needed in backward
        out = self.module(input)    # gradient needed in backward
        b = function_b(out)         # no gradient needed in backward
        return b

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = CustomModule(nn.Conv3d(1, 100, kernel_size=1))
        self.relu1 = CustomModule(nn.ReLU())

    def forward(self, input):
        conv1 = self.conv1(input)
        relu1 = self.relu1(conv1)
        return relu1


The gradients are computed using backpropagation. So to get the gradients for a module, every use of its output must go through differentiable functions.
If you do:

a = func1(input)
b = module(a)
c = func2(b)
loss = crit(c)

and your learnable parameters are in module, then you will need func2 and crit to be differentiable, but not func1.
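Not part of the original post, but a minimal runnable sketch of this point: func1 here is a hypothetical non-differentiable preprocessing step (it detours through NumPy, so autograd cannot trace it), yet the module's parameters still receive gradients because func1 runs before the module and the input itself requires no grad.

```python
import torch
import torch.nn as nn

def func1(x):
    # Hypothetical non-differentiable op: goes through NumPy,
    # so autograd cannot track it. This is fine because it runs
    # *before* the module holding the learnable parameters.
    return torch.from_numpy(x.numpy() * 2.0)

def func2(x):
    # Must stay differentiable: it runs *after* the module.
    return x.tanh()

module = nn.Linear(4, 3)

inp = torch.randn(2, 4)      # plain data, requires_grad=False
a = func1(inp)               # autograd graph starts only after this
b = module(a)
c = func2(b)
loss = c.sum()
loss.backward()

print(module.weight.grad is not None)  # True: gradients reach module
```

If func1 were instead applied to the *output* of module, backward would fail (or silently stop) at the NumPy boundary, which is exactly the chained case discussed below.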

Hi @albanD,

Thank you so much for your response.

Is it true that if my learnable parameters are in module1 and module2, and I do:

a1 = func1(input)
b1 = module1(a1)
c1 = func2(b1)

a2 = func1(c1)
b2 = module2(a2)
c2 = func2(b2)

loss = crit(c2)

(1) Will I need only crit and func2 to be differentiable? How does the backward pass work in the aforementioned scenario?

(2) As I understand it, my problem is a kind of disconnected graph, since in the backward pass I just want to compute the gradients of loss -> b2 -> b1. Is there any way to avoid autograd and do the backpropagation manually?


  1. No: if you chain them, then func1 will also need to be differentiable to get gradients for module1, because func1 sits between c1 and module2 in the graph.
  2. You can use the information in this part of the doc to create your own autograd Function. In your case, you can use func1 and func2 in the forward, and you will need to implement the backward for them yourself.
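To make the second suggestion concrete, here is a minimal sketch (not from the original thread) of a custom torch.autograd.Function. The "black box" forward is a hypothetical stand-in (multiplication by 3 done through NumPy, invisible to autograd), and the hand-written backward supplies the exact gradient:

```python
import torch

class Func2(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Hypothetical non-differentiable computation: autograd
        # cannot see through the NumPy round-trip.
        return torch.from_numpy(x.detach().numpy() * 3.0)

    @staticmethod
    def backward(ctx, grad_output):
        # Hand-implemented gradient of y = 3 * x.
        return grad_output * 3.0

x = torch.randn(5, requires_grad=True)
y = Func2.apply(x)
y.sum().backward()

print(torch.allclose(x.grad, torch.full_like(x, 3.0)))  # True
```

In your network, each of func1 and func2 would get such a Function subclass, and CustomModule would call them via .apply() so gradients can flow through the whole chain.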

Thank you so much @albanD