How to use no_grad() for just a custom function

Let's say I have the following network:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))

        # custom step that should not be tracked by autograd
        x = self.costumefunction(x)

        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def costumefunction(self, x):
        with torch.no_grad():
            # a couple of functions and operations to do something stupid;
            # clone() is only a placeholder for those operations
            output = x.clone()
        return output

Can I design a custom function, use it inside my network as above, and disable backprop just for costumefunction, the way I did?

My concern is that I want to disable backprop for everything inside the function, but I don't want backprop to be disabled for the rest of the forward pass. In other words, I want back propagation again for the steps after the call to costumefunction.
Is that even possible?


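torch.no_grad() means that the operations inside it are not recorded by autograd, so no gradient can be computed through them.
Recording resumes as soon as you leave the block, so conv2 and the fc layers will still get gradients. But the output of costumefunction has no history linking it back to conv1, so conv1 will get no gradient at all. In your case, if you want gradients to flow back to conv1, then you actually need to compute the gradients through costumefunction.
What are you trying to achieve here?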
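A minimal sketch to check this, assuming a trimmed-down version of the network above (a single fully connected layer instead of three) and a hypothetical `x * 2.0` standing in for the real operations inside costumefunction:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 10)  # trimmed: one linear layer only

    def costumefunction(self, x):
        with torch.no_grad():
            # hypothetical stand-in for the real operations;
            # the result is not connected to the autograd graph
            return x * 2.0

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.costumefunction(x)           # graph is cut here
        x = self.pool(F.relu(self.conv2(x)))  # recording resumes here
        x = x.view(-1, 16 * 5 * 5)
        return self.fc1(x)

net = Net()
out = net(torch.randn(2, 3, 32, 32))  # 32x32 input gives 16x5x5 before fc1
out.sum().backward()

print(net.conv1.weight.grad)              # None: no path back through no_grad
print(net.conv2.weight.grad is not None)  # True: layers after the block still train
```

So the layers after the `no_grad` block are trained normally, but everything before it is silently frozen.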

Thank you, that's exactly what I was confused about.
I want the gradient to flow back to conv1, but I noticed that computing the gradient inside costumefunction makes the process super slow, which is why I wondered whether I could turn it off inside costumefunction.
Since I need the gradient to flow through the whole net, I guess I have no choice but to compute it inside costumefunction as well…


If this function is simple but just too slow (because it performs a lot of operations), one thing you can try is to convert it into a hand-made autograd Function (see the doc). You implement both the forward and the backward yourself, which replaces the automatic differentiation; autodiff can be slow when the function is made of many small operations, because each one is recorded and differentiated separately.
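A minimal sketch of that idea, assuming a hypothetical costumefunction that computes `x * sigmoid(x)`; the real function would need its own derivative worked out by hand. `torch.autograd.gradcheck` compares the hand-written backward against numerical gradients (it needs double-precision inputs to be reliable):

```python
import torch

class SlowOpsFunction(torch.autograd.Function):
    """Hypothetical example: y = x * sigmoid(x) with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        # Operations inside forward are not recorded by autograd,
        # so no graph is built for these intermediate steps.
        ctx.save_for_backward(x)
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        s = torch.sigmoid(x)
        # d/dx [x * sigmoid(x)] = s + x * s * (1 - s)
        return grad_output * (s + x * s * (1 - s))

x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
y = SlowOpsFunction.apply(x)
y.sum().backward()  # gradient still flows back to x

ok = torch.autograd.gradcheck(SlowOpsFunction.apply, (x,))
print(ok)  # True
```

Inside the network, the call would become `x = SlowOpsFunction.apply(x)`, and conv1 would still receive gradients. Whether this is actually faster depends on the function: the gain comes from replacing many small recorded operations with one backward you wrote yourself.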
