torch.no_grad() vs detach()

I understand that inside the no_grad() context, autograd does not record the computation graph, which is similar to temporarily setting requires_grad to False, whereas the detach() function returns a tensor that is detached from the computation graph.
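A minimal sketch of the two mechanisms described above (the tensor names are just for illustration): results computed under no_grad() carry no grad history, and detach() returns a new tensor that also has no grad history but shares storage with the original.

```python
import torch

# A leaf tensor that autograd tracks.
w = torch.randn(3, requires_grad=True)

# Inside no_grad(), the result of the multiplication is not recorded in the graph.
with torch.no_grad():
    y_ng = w * 2

# detach() returns a tensor without grad history that aliases w's memory.
y_det = w.detach()

print(y_ng.requires_grad, y_ng.grad_fn)    # False None
print(y_det.requires_grad)                 # False
print(y_det.data_ptr() == w.data_ptr())    # True: same underlying storage
```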

My question is: is there any situation where using detach() is necessary? It seems to me that we can always do everything with no_grad() instead.

Thanks

There is quite a bit of fine print to the rough statement that “they have the same effect”. For example:

  • even in no_grad mode, views are still tracked (and have requires_grad set if they are views of a tensor that has it).
  • detach is more versatile because you can choose exactly which tensors should not have gradients (e.g. if you want to train only the last layer or the last few layers for fine-tuning).
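A rough sketch of the fine-tuning case from the second point, assuming a made-up two-part model (backbone and head are hypothetical names): detaching the backbone output stops gradients from reaching the backbone while the head still trains. The second part shows detach() cutting gradient flow through just one factor inside a single expression, which with no_grad() would require computing that factor separately first.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model: a backbone we want frozen and a head we want to train.
backbone = nn.Linear(4, 4)
head = nn.Linear(4, 1)

x = torch.randn(8, 4)

# Run the backbone, then cut the graph so no gradient flows back into it.
features = backbone(x).detach()
loss = head(features).pow(2).mean()
loss.backward()

print(backbone.weight.grad)          # None: the backbone got no gradient
print(head.weight.grad is not None)  # True: the head got a gradient

# Stop the gradient through one occurrence of a tensor within one expression:
# d(a * const)/da = const, where const is the detached value of a.
a = torch.tensor(3.0, requires_grad=True)
(a * a.detach()).backward()
print(a.grad)  # tensor(3.)
```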

@tom
I ran into a similar problem. I have code like the following, where only the last frame computes the loss and runs backpropagation, but all frames use the same parameters.
It runs with find_unused_parameters=True, but the weights are not updated. I need the weights to be updated by the backward pass of the last frame. In this case, should I use detach() instead?

import numpy as np
import torch
import torch.nn as nn

class DummyDataset(torch.utils.data.Dataset):

    def __init__(self, seq_len=5):
        self.seq_len = seq_len
        

    def __len__(self):
        return 200

    def __getitem__(self, idx):
        return np.random.rand(self.seq_len, 10).astype(np.float32), np.random.rand(1).astype(np.float32)
        
class DummyNet(nn.Module):

    def __init__(self):
        super().__init__()

        self.mlp_fea = nn.Linear(10, 10)
        self.mlp_out = nn.Linear(10, 1)
    
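    def forward(self, inputs, labels, training=False):
        # Move the batch to the device of the model parameters
        # (instead of hard-coding .cuda()).
        device = self.mlp_fea.weight.device
        inputs = inputs.to(device)
        labels = labels.to(device)

        B, SEQ, C = inputs.shape
        print(self.mlp_fea.weight)  # debug: check whether the weights change between steps

        # Recurrent feature that is fed back into mlp_fea at every frame.
        fea_prev = torch.zeros_like(inputs[:, 0])
        for i in range(SEQ):
            x = inputs[:, i]
            if i < SEQ - 1:
                # Earlier frames: run without recording the autograd graph.
                was_training = self.training
                self.eval()
                with torch.no_grad():
                    fea_prev = self.mlp_fea(x + fea_prev)
                self.train(was_training)  # restore the previous mode instead of forcing train()
            else:
                # Last frame: tracked by autograd, so only it produces gradients.
                fea_prev = self.mlp_fea(x + fea_prev)
                out = self.mlp_out(fea_prev)
                loss = nn.L1Loss(reduction='mean')(out, labels)
                return loss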
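To check whether switching to detach() would change anything, here is a single-process CPU sketch of the same recurrence (last_frame_loss is a hypothetical helper, and it does not reproduce the DDP setup). It runs the earlier frames either under no_grad() or with detach() and compares the gradient that mlp_fea's weight receives from the last frame. In this sketch the two gradients match, because in both cases the earlier frames only contribute constant values; the detach() variant additionally builds and then discards the graph for those frames.

```python
import torch
import torch.nn as nn

def last_frame_loss(fea_net, out_net, inputs, labels, use_detach):
    """Hypothetical helper: recurrence over frames, graph kept only for the last frame."""
    B, SEQ, C = inputs.shape
    fea_prev = torch.zeros(B, C)
    for i in range(SEQ - 1):
        x = inputs[:, i]
        if use_detach:
            # Build the graph for this frame, then cut it off.
            fea_prev = fea_net(x + fea_prev).detach()
        else:
            # Do not record the graph for this frame at all.
            with torch.no_grad():
                fea_prev = fea_net(x + fea_prev)
    # Only the last frame is tracked by autograd.
    fea_last = fea_net(inputs[:, -1] + fea_prev)
    return nn.functional.l1_loss(out_net(fea_last), labels)

torch.manual_seed(0)
fea_net, out_net = nn.Linear(10, 10), nn.Linear(10, 1)
inputs, labels = torch.rand(4, 5, 10), torch.rand(4, 1)

grads = []
for use_detach in (False, True):
    fea_net.zero_grad()
    out_net.zero_grad()
    last_frame_loss(fea_net, out_net, inputs, labels, use_detach).backward()
    grads.append(fea_net.weight.grad.clone())

print(torch.allclose(grads[0], grads[1]))  # True
```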