Post-processing model output while retaining gradients

Hello,

I am trying to train a model in the following fashion:

# in a training loop, getting a batch x and a target y_true
optimizer.zero_grad()  # clear gradients accumulated by the previous iteration
output = model(x)
output_processed = process_output(output)
loss = criterion(output_processed, y_true)
loss.backward()
optimizer.step()
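A self-contained version of this loop might look like the sketch below. The model (`nn.Linear`), loss (`nn.MSELoss`), optimizer (`torch.optim.SGD`), the sizes, and the body of `process_output` are hypothetical stand-ins; only the loop structure comes from the post. The `optimizer.zero_grad()` call matters because `backward()` accumulates into `.grad` rather than overwriting it.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes; the only constraint from the post is some_other_size > output_size.
batch_size, input_size, output_size, some_other_size = 8, 4, 6, 7

model = nn.Linear(input_size, output_size)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)


def process_output(output: torch.Tensor) -> torch.Tensor:
    # Hypothetical post-processing: append the row sum as an extra column.
    return torch.cat([output, output.sum(dim=1, keepdim=True)], dim=1)


x = torch.randn(batch_size, input_size)
y_true = torch.randn(batch_size, some_other_size)

losses = []
for step in range(20):
    optimizer.zero_grad()                      # clear gradients left over from the previous step
    output = model(x)                          # (batch_size, output_size)
    output_processed = process_output(output)  # (batch_size, some_other_size)
    loss = criterion(output_processed, y_true)
    loss.backward()                            # backpropagates through process_output into model
    optimizer.step()
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```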

I am facing issues implementing the process_output function. Some details on the context:

  • my output tensor is of shape (batch_size, output_size)
  • my output_processed tensor is of shape (batch_size, some_other_size) where some_other_size > output_size
  • the internals of process_output use slicing to assign values, for instance output_processed[:, (0, 1, 2)] = output[:, (0, 1, 2)] + output[:, (3, 4, 5)] (operations may be more complex, but they all consist of elementary torch operations plus slicing and broadcasting tricks)

Now, here is the issue I’m facing. I want to update the model’s weights using the loss computed from output_processed, so process_output begins like this:

import torch
from torch import Tensor

def process_output(output: Tensor) -> Tensor:
    # some_other_size is defined elsewhere, with some_other_size > output_size
    output_processed = (
        torch.empty(size=(output.shape[0], some_other_size))
    ).to(output.device, output.dtype).requires_grad_(True)

    # perform operations: this raises RuntimeError
    output_processed[:, (0, 1, 2)] = output[:, (0, 1, 2)] * 2
    # ... other operations of the sort ...

    return output_processed

As commented in the function above, the first operation fails and raises RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation. I am not sure how to proceed here because I don’t think I understand requires_grad and leaf variables correctly.
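The error can be reproduced in isolation. The sketch below (sizes and values are made up) creates a tensor explicitly marked with `requires_grad_(True)`, as `process_output` does, which makes it a leaf that requires grad, and then writes into a slice of it. Autograd does not allow in-place operations on such a leaf or on a view of one. The exact wording of the message differs between PyTorch versions, so the code only looks for "leaf Variable" in it.

```python
import torch

param = torch.ones(2, 3, requires_grad=True)  # stand-in for a model parameter
output = 2 * param                              # stand-in for the model output (non-leaf)

# Mimics process_output: a fresh tensor explicitly marked requires_grad=True is a leaf.
output_processed = torch.empty(2, 4).requires_grad_(True)
print(output_processed.is_leaf)  # True

message = None
try:
    # Writing into a slice of that leaf is an in-place operation on a view of a leaf.
    output_processed[:, 0:3] = output[:, 0:3] * 2
except RuntimeError as err:
    message = str(err)
    print("RuntimeError:", message)
```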

From my understanding, setting requires_grad_(False) when instantiating the tensor should “break” my gradient computation; that is, calling loss.backward() would result in a wrong update of the model’s weights because I’m not “tracking” the gradients in output_processed. However, a user replied on a similar post that he managed to correctly update his model: Leaf variable was used in an inplace operation - #14 by nima_rafiee.

The more I read about leaf variables, the more confused I get. The docs for torch.Tensor.is_leaf say that “Tensors that have requires_grad which is False will be leaf Tensors”, but then say “Only leaf Tensors will have their grad populated during a call to backward()”. So a Tensor with requires_grad=False does not track gradients, yet it is a leaf Tensor, so does it get its grad populated?
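The two statements in the `is_leaf` docs fit together once you note that `.grad` is populated only for leaf tensors that also have `requires_grad=True`. A tensor with `requires_grad=False` is a leaf by convention, but `backward()` never writes into its `.grad`; a tensor produced by an operation on something that requires grad is a non-leaf, and its `.grad` is not kept by default (`retain_grad()` changes that). The toy tensors below are purely illustrative.

```python
import torch

w = torch.ones(3, requires_grad=True)  # created by the user, requires grad: leaf
c = torch.full((3,), 5.0)              # requires_grad=False: leaf by convention
y = w * c                              # result of an op on w: non-leaf with a grad_fn

y.sum().backward()

print(w.is_leaf, c.is_leaf, y.is_leaf)  # True True False
print(w.grad)                           # tensor([5., 5., 5.])  since d(sum(w*c))/dw = c
print(c.grad)                           # None: c is a leaf but does not require grad
```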

So, should I set requires_grad_(True) for output_processed or not? If yes, would creating separate Tensors from the operations in process_output and then concatenating them at the end, rather than slicing into an already instantiated Tensor, be a good way to deal with the error? Thanks!
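On the concatenation question: building the pieces separately and joining them with `torch.cat` is a common out-of-place alternative, and it should give the same values and gradients as slice assignment into a buffer that does not require grad. The sketch below checks this on a made-up layout (output_size = 6, some_other_size = 9); both helper functions are hypothetical.

```python
import torch


def process_with_slices(output: torch.Tensor) -> torch.Tensor:
    # Buffer that does NOT require grad; slice assignment links it to the graph.
    out = output.new_zeros((output.shape[0], 9))
    out[:, 0:3] = output[:, 0:3] + output[:, 3:6]
    out[:, 3:9] = output
    return out


def process_with_cat(output: torch.Tensor) -> torch.Tensor:
    # Same layout built out-of-place: compute each block, then concatenate.
    summed = output[:, 0:3] + output[:, 3:6]
    return torch.cat([summed, output], dim=1)


torch.manual_seed(0)
param = torch.randn(4, 6, requires_grad=True)  # stand-in for a model parameter
weights = torch.randn(4, 9)                    # arbitrary weights for a dummy loss

out_s = process_with_slices(2 * param)
out_c = process_with_cat(2 * param)
(grad_s,) = torch.autograd.grad((out_s * weights).sum(), param)
(grad_c,) = torch.autograd.grad((out_c * weights).sum(), param)

print(torch.equal(out_s, out_c), torch.allclose(grad_s, grad_c))  # True True
```

Either style works with autograd; the concatenation version simply avoids in-place writes altogether.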

Hi Louis!

Just implement process_output() using (differentiable) pytorch tensor
functions and everything should work automatically. (Note, slicing counts
as a differentiable tensor function.)
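Applied to the process_output from the question, this mostly means dropping the `requires_grad_(True)` call: allocate the buffer without requiring grad and let the slice assignments connect it to the graph. A minimal sketch, assuming `some_other_size` is passed in as an argument (in the original it comes from the surrounding scope) and using a dummy parameter in place of the model. `torch.zeros` replaces `torch.empty` so that columns never written to hold zeros rather than uninitialized memory.

```python
import torch
from torch import Tensor


def process_output(output: Tensor, some_other_size: int) -> Tensor:
    # No requires_grad_(True): the buffer starts out with requires_grad=False.
    output_processed = torch.zeros(
        output.shape[0], some_other_size, device=output.device, dtype=output.dtype
    )
    # Slice assignment is differentiable; output_processed picks up a grad_fn here.
    output_processed[:, (0, 1, 2)] = output[:, (0, 1, 2)] * 2
    # ... other operations of the sort ...
    return output_processed


param = torch.ones(2, 4, requires_grad=True)  # stand-in for a model parameter
output = 3 * param                             # stand-in for the model output
processed = process_output(output, some_other_size=6)
processed.sum().backward()

print(processed.grad_fn is not None)  # True
print(param.grad)
# tensor([[6., 6., 6., 0.],
#         [6., 6., 6., 0.]])
```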

As far as autograd and updating a model’s weights are concerned, there
is nothing different about tensor functions that occur inside of a model,
and tensor functions that occur outside of a model, such as those in
a post-processing step.

Consider:

>>> import torch
>>> print (torch.__version__)
2.0.0
>>>
>>> param = torch.ones (2, 3, requires_grad = True)   # some parameter in your model
>>> output = 2 * param                                # dummy output of model
>>> output_processed = torch.zeros (2, 4)
>>> output_processed                                  # no connection to gradients yet
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.]])
>>> output_processed.requires_grad                    # starts out False
False
>>> output_processed[:, 1] = output[:, 0]             # use slice assignment
>>> output_processed                                  # now carries a grad_fn so you can backpropagate
tensor([[0., 2., 0., 0.],
        [0., 2., 0., 0.]], grad_fn=<CopySlices>)
>>> output_processed.requires_grad                    # "inherited" True from param -- wasn't set explicitly
True
>>> loss = output_processed.sum()                     # some dummy loss
>>> loss.backward()                                   # gradients backpropagate through output_processed
>>> param.grad                                        # to param.grad
tensor([[2., 0., 0.],
        [2., 0., 0.]])
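The earlier point, that autograd treats tensor functions inside a model and outside it alike, can also be checked directly: run the same post-processing either inside a module's `forward()` or after the module call, and compare the parameter gradients. The `Net` class, its `post_process_inside` flag, and `post_process` are hypothetical, made up for this comparison.

```python
import torch
from torch import nn


def post_process(output: torch.Tensor) -> torch.Tensor:
    # Hypothetical post-processing with slice assignment into a fresh buffer.
    out = output.new_zeros((output.shape[0], output.shape[1] + 1))
    out[:, :-1] = output
    out[:, -1] = output[:, 0] * output[:, 1]
    return out


class Net(nn.Module):
    def __init__(self, post_process_inside: bool):
        super().__init__()
        self.linear = nn.Linear(3, 2)
        self.post_process_inside = post_process_inside

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.linear(x)
        return post_process(y) if self.post_process_inside else y


torch.manual_seed(0)
inside, outside = Net(True), Net(False)
outside.load_state_dict(inside.state_dict())  # identical weights in both models

x = torch.randn(5, 3)
inside(x).pow(2).sum().backward()               # post-processing inside forward()
post_process(outside(x)).pow(2).sum().backward()  # post-processing after the model call

print(torch.allclose(inside.linear.weight.grad, outside.linear.weight.grad))  # True
```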

Best.

K. Frank


Hi Frank! Thank you very much for your answer.

I noticed while debugging that, even when explicitly specifying requires_grad_(False), my tensor inherited the requires_grad property of my model output, but I hadn’t seen the gradients propagate explicitly as you showed here. You saved me quite a bit of refactoring :smile:

Best,
Louis