Is there a way to compute weight gradients in PyTorch inside torch.no_grad()?

I know the question seems to answer itself (computing weight gradients without gradients?). But the problem I am trying to solve is that I need to modify leaf variables before computing the backward pass, and without no_grad PyTorch complains about in-place modifications.
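The two complaints PyTorch raises here can be reproduced with a single leaf tensor. The sketch below uses a hypothetical tensor `w` in place of a real layer weight: modifying it in place with grad mode on fails immediately, and modifying it under `no_grad` after it was saved for backward makes `backward()` fail with a version-counter error.

```python
import torch

# A leaf tensor tracked by autograd, standing in for a layer weight.
w = torch.ones(3, requires_grad=True)

# 1) In-place update of a leaf that requires grad, with grad mode on.
try:
    w *= 3
    leaf_error = None
except RuntimeError as e:
    leaf_error = str(e)
print("leaf in-place:", leaf_error)

# 2) The same update under no_grad is allowed, but pow saved w for its
#    backward pass, so autograd's version check fails at backward time.
loss = (w ** 2).sum()
with torch.no_grad():
    w *= 3                      # allowed: nothing is recorded here
try:
    loss.backward()
    version_error = None
except RuntimeError as e:
    version_error = str(e)
print("backward after in-place:", version_error)
```

So no_grad only silences the first error; the second one still protects the graph from being evaluated with tensors that changed after the forward pass.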

  1. Define the model (it is split into two stages of a pipeline):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Pipe1(nn.Module):
    """First pipeline stage: the two conv blocks."""
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        return x


class Pipe2(nn.Module):
    """Second pipeline stage: flatten plus the two fully connected layers."""
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = x.view(-1, 320)  # 20 channels x 4 x 4 for 28x28 inputs
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
```

  2. Make the pipeline:

```python
pipe = [Pipe1(), Pipe2()]
```

  3. Run a micro-batch through the pipeline:

```python
# train_loader is defined elsewhere
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(train_loader):
        output = pipe[1](pipe[0](data))
        loss = F.nll_loss(output, target)
        # Allowed only because autograd is disabled inside no_grad
        pipe[0].conv1.weight *= 3
```
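To make the problem with step 3 concrete, here is a minimal sketch that swaps the two pipeline stages for a single hypothetical `nn.Linear` layer and random data (`layer`, `data` and `target` are illustrative, not from the original). Everything computed inside `no_grad` has no `grad_fn`, so there is no graph for `loss.backward()` to traverse.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
layer = nn.Linear(4, 2)                # hypothetical stand-in for the pipeline
data = torch.randn(5, 4)               # hypothetical micro-batch
target = torch.randint(0, 2, (5,))

with torch.no_grad():
    output = F.log_softmax(layer(data), dim=1)
    loss = F.nll_loss(output, target)
    layer.weight *= 3                  # allowed: autograd is disabled

# No graph was recorded, so backward() has nothing to work with.
print(loss.requires_grad, loss.grad_fn)  # False None
try:
    loss.backward()
    backward_error = None
except RuntimeError as e:
    backward_error = str(e)
print("backward failed:", backward_error)
```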

And what would step 4 be? Given the current loss, I would ideally compute the derivatives with respect to the newly updated weights. Is there a "proper" way of doing this in PyTorch, or would I need to compute the gradients myself?

This approach fails because the intermediate forward activations were not produced by the modified weights and are therefore wrong (and, since the forward pass ran under no_grad, no graph was recorded for backward() to use in the first place). You could try to pass the new gradients directly to this particular layer, but that sounds like a painfully manual approach.
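Following that reasoning, one straightforward way to get gradients at the updated weights is to apply the change first and then recompute the forward pass with autograd enabled, so that every activation comes from the new weights. A minimal sketch, again with a hypothetical single `nn.Linear` layer and random data in place of the two-stage pipeline:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
layer = nn.Linear(4, 2)                # hypothetical stand-in for the pipeline
data = torch.randn(5, 4)
target = torch.randint(0, 2, (5,))

# 1) Change the weights first, outside of autograd.
with torch.no_grad():
    layer.weight *= 3

# 2) Re-run the forward pass with grad enabled, so the recorded graph
#    and its saved activations all come from the modified weights.
output = F.log_softmax(layer(data), dim=1)
loss = F.nll_loss(output, target)
loss.backward()

# layer.weight.grad now holds dL/dW evaluated at the new weights.
print(layer.weight.grad)
```

The cost is an extra forward pass per micro-batch; in the pipeline above that would mean re-running both `Pipe1` and `Pipe2` after modifying `conv1.weight`.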

I was afraid that would be the case. Thank you for your answer! What exactly did you mean by passing the gradients manually?
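One possible reading of "passing the gradients manually" (an interpretation, not necessarily what the answerer had in mind) relies on `Tensor.backward(gradient=...)`: keep the gradient that arrives at the boundary of the first stage, recompute only that stage with the modified weights, and feed the stored gradient into it. In this sketch `stage1` and `stage2` are hypothetical small layers standing in for `Pipe1` and `Pipe2`; a `Tanh` is included so the weight gradient actually depends on the weight values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
stage1 = nn.Sequential(nn.Linear(4, 8), nn.Tanh())  # hypothetical Pipe1
stage2 = nn.Linear(8, 3)                             # hypothetical Pipe2
data = torch.randn(5, 4)
target = torch.randint(0, 3, (5,))

# Forward through stage 1, then cut the graph at the stage boundary,
# as a pipeline does when the stages run separately.
act = stage1(data)
act_in = act.detach().requires_grad_()
loss = F.nll_loss(F.log_softmax(stage2(act_in), dim=1), target)
loss.backward()
grad_act = act_in.grad.clone()  # dL/d(stage-1 output), from the OLD weights

# Modify stage 1's weights, recompute only stage 1, and hand the stored
# upstream gradient to it explicitly.
with torch.no_grad():
    stage1[0].weight *= 3
act_new = stage1(data)
act_new.backward(gradient=grad_act)

# Local derivative at the new weights times a stale upstream gradient.
print(stage1[0].weight.grad.shape)  # torch.Size([8, 4])
```

Note that `grad_act` was computed by stage 2 from activations produced by the old weights, so the resulting gradient is only an approximation unless stage 2 is recomputed as well, which is part of what makes this route fragile.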