`weight.grad.data` is not necessary

One official PyTorch tutorial has the comment below:

    # An alternative way is to operate on **weight.data** and **weight.grad.data**.
    # Recall that tensor.data gives a tensor that shares the storage with
    # tensor, but doesn't track history.

https://github.com/pytorch/tutorials/blob/7f48ac827389f3438113c1604ac416cca3c1090e/beginner_source/examples_autograd/two_layer_net_autograd.py#L59-L72
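
For context, the update that comment is about looks roughly like this (a minimal sketch with a single hypothetical weight tensor, not the full tutorial script):

    import torch

    learning_rate = 1e-2
    weight = torch.randn(10, 5, requires_grad=True)

    loss = (weight ** 2).sum()  # stand-in for a real loss
    loss.backward()

    # Update style used in the tutorial body: in-place ops inside torch.no_grad()
    with torch.no_grad():
        weight -= learning_rate * weight.grad
        weight.grad.zero_()

    # Alternative mentioned in the comment: operate on .data / .grad.data instead
    loss = (weight ** 2).sum()
    loss.backward()
    weight.data -= learning_rate * weight.grad.data
    weight.grad.data.zero_()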

The requires_grad of weight is True. weight.data gives a tensor that shares the same underlying data with weight, but whose requires_grad is False.

Since weight.grad is a non-tracking tensor (requires_grad = False), it seems that weight.grad.data is the same as weight.grad.

In short, .data or .detach() matters only for tensors that track history, i.e. requires_grad is True; for other tensors it does nothing.
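
A quick check of this (a small sketch, assuming a plain backward() call):

    import torch

    weight = torch.randn(3, 3, requires_grad=True)
    (weight ** 2).sum().backward()

    print(weight.requires_grad)       # True
    print(weight.data.requires_grad)  # False: shares storage with weight, but no history
    print(weight.grad.requires_grad)  # False here, so .data changes nothing about tracking
    # .data returns a new tensor object, but it points at the same storage
    print(weight.grad.data.data_ptr() == weight.grad.data_ptr())  # True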

Please correct me.


As you observe, in many circumstances, a.grad will not require gradient.
This is, however, not universally true:

    import torch

    a = torch.randn((5, 5), requires_grad=True)
    (a ** 2).sum().backward(create_graph=True)
    print(a.grad.requires_grad)  # True: with create_graph=True the gradient itself tracks history

Note that .detach() is universally better than .data, except when you use it as a shorthand for something involving `with torch.no_grad():` (which is the case for optimizers).
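
As a small illustration of that last point, with a hypothetical parameter p (not from the posts above):

    import torch

    p = torch.randn(4, requires_grad=True)
    (p ** 2).sum().backward()

    # Preferred in most code: .detach() gives an untracked view
    print(p.detach().requires_grad)  # False

    # The "shorthand" case: these two in-place updates are equivalent here
    p.data -= 0.1 * p.grad
    with torch.no_grad():
        p -= 0.1 * p.grad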

Best regards

Thomas


Got it,

weight.grad.data can handle the backward-of-backward scenario, and it doesn't harm tensors that won't require gradients.
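
A sketch that puts both points together (reusing the create_graph example from above):

    import torch

    a = torch.randn((5, 5), requires_grad=True)
    (a ** 2).sum().backward(create_graph=True)

    print(a.grad.requires_grad)       # True: the gradient itself is part of a graph
    print(a.grad.data.requires_grad)  # False: .data gives an untracked view of the same values

    # So an update through .data is not recorded into the double-backward graph
    a.data -= 0.1 * a.grad.data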

Hi,

I am trying to manually update weights while using gradients from autograd (I need this as a base to implement my custom weight update algorithm later). So, based on the discussion above, is the following good practice in general: using .grad.data and .grad.data.zero_() instead of .grad and .grad.zero_()?

Note: I am using named_parameters() for debugging purposes only; the name is not used for computations inside the loop.

    with torch.no_grad():
        for name, p in model.named_parameters():
            p.data = p.data - learning_rate * p.grad.data
            # Manually zero the gradients after updating weights
            p.grad.data.zero_()
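
For comparison, the same loop written without .data, relying only on the torch.no_grad() context (same hypothetical model and learning_rate as above):

    with torch.no_grad():
        for name, p in model.named_parameters():
            p -= learning_rate * p.grad
            p.grad.zero_()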

Regards, Sumit