How am I supposed to cache Tensors that require grad, aren't learnable Module params, and are replaced by a different Tensor several times each forward pass?

I'm not 100% sure the question actually addresses my problem, so let me restate it:
I am trying to get the Jacobian of an RNN(nn.Module) with respect to an input batch, but after calling outputs.backward() the input's .grad is still None. I narrowed down what I think the problem is with these strategies:

  1. I double-checked that the input and all of the model parameters require grad.
  2. I went from the outputs (which have a grad_fn) back through the model to find the last param whose param.grad is None.

This last gradless param is the net.activs parameter which gets replaced on every RNN step with a new nn.Parameter instance containing the activation values:

import torch
import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, many, args):
        # ...
        self.requires_grad = None # set this appropriately before forward pass to know how to

    def reset(self, batch_size=None):
        """reset internal batch activation tensors which aren't learnable model params"""
        if batch_size is None:
            batch_size = self.batch_size
        else:
            self.batch_size = batch_size
        if self.n_hidden > 0:
            self.activs = nn.Parameter(torch.zeros(batch_size, self.n_hidden, dtype=self.dtype))
        else:
            self.activs = None
        self.outputs = nn.Parameter(torch.zeros(batch_size, self.n_outputs, dtype=self.dtype))

    def forward(self, inputs):
        inputs = torch.tensor(inputs, dtype=self.dtype)
        activs_for_output = self.activs
        if self.n_hidden > 0:
            # recurrent loopdeloop
            for _ in range(self.n_internal_steps):
                self.activs = nn.Parameter(self.activation(
                    self.hidden_responses * (
                        self.input_to_hidden(inputs) +
                        self.hidden_to_hidden(self.activs) +
                        self.output_to_hidden(self.outputs)
                    ) + self.hidden_biases)).requires_grad_(self.requires_grad)
            if self.use_current_activs:
                activs_for_output = self.activs
        output_inputs = (self.input_to_output(inputs) +
                         self.output_to_output(self.outputs))
        if self.n_hidden > 0:
            output_inputs += self.hidden_to_output(activs_for_output)
        self.outputs = nn.Parameter(self.activation(
            self.output_responses * output_inputs + self.output_biases)).requires_grad_(self.requires_grad)
        return self.outputs

def main():
    net = RNN()
    net.requires_grad_(True)
    net.requires_grad = True

    # try to get jacobian of probe batch
    inputs = torch.rand(batch_size, features).requires_grad_(True)
    outputs = net(inputs)
    outputs.backward(torch.ones_like(outputs))

    # forward
    assert inputs.requires_grad
    assert net.activs.requires_grad
    assert net.outputs.requires_grad
    assert outputs.requires_grad

    # backward
    assert net.outputs.grad is not None # net.outputs == outputs
    assert net.activs.grad is not None # <---- the error if I remove all nn.Parameter() wraps above
    assert inputs.grad is not None # <---- my error

    jacobian = inputs.grad.clone().detach().view(batch_size, -1).cpu().numpy()

I think the problem is that I replace the net.activs parameter with a new instance several times during net.forward, and refer to it on successive forward passes. Should I avoid setting this parameter as an attribute of the model? If so, how else should I cache it?

What I’ve tried:

  1. Not using nn.Parameter at all (see the alternative error in the code above), which leads to the output having no grad.
  2. Replacing the self.activs (and self.outputs) assignments with in-place writes to self.activs[:], which yields: RuntimeError: leaf variable has been moved into the graph interior
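
A minimal sketch of what the nn.Parameter wrapping does to the graph, assuming only stock PyTorch. The names x, w, hidden and wrapped are hypothetical stand-ins, not taken from the model above. It shows that wrapping a computed tensor in nn.Parameter yields a fresh leaf with no grad_fn, so backward() stops at that leaf and never reaches the input:

```python
import torch
import torch.nn as nn

x = torch.rand(4, 3, requires_grad=True)  # stand-in for the input batch (hypothetical)
w = torch.rand(3, 2, requires_grad=True)  # stand-in for a weight matrix (hypothetical)

hidden = torch.tanh(x @ w)      # non-leaf: its grad_fn links back to x and w
wrapped = nn.Parameter(hidden)  # new leaf tensor: the history is cut here

print(hidden.is_leaf, hidden.grad_fn is None)    # non-leaf with a grad_fn
print(wrapped.is_leaf, wrapped.grad_fn is None)  # leaf without a grad_fn

wrapped.sum().backward()
print(wrapped.grad is not None)  # the gradient lands on the new leaf
print(x.grad is None)            # nothing flowed back to the input
```

As far as I can tell, torch.tensor(inputs) in forward has the same effect, since it copies its argument and detaches the copy from the graph.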

I will try caching these Tensors somewhere else now.

I stopped suppressing warnings and got: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.

which I'm now reading.
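
For reference, a small sketch of what the warning suggests, using hypothetical tensors x and h rather than the model above. Calling .retain_grad() on a non-leaf tensor before backward() asks autograd to keep that tensor's .grad:

```python
import torch

x = torch.rand(4, 3, requires_grad=True)  # leaf tensor (hypothetical input)
h = torch.tanh(x * 2.0)                   # non-leaf intermediate result
h.retain_grad()                           # keep h.grad after backward()

h.sum().backward()

print(h.is_leaf)     # False: h was produced by an operation
print(h.grad.shape)  # populated only because of retain_grad()
print(x.grad.shape)  # leaf gradients are populated as usual
```

Without the retain_grad() call, accessing h.grad after backward() would return None and trigger the warning quoted above.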

I solved my problem by:

  1. Not making activs and outputs nn.Parameters
  2. Not assigning them as model attributes. Instead I added them as optional keyword arguments to the forward method and returned the activations as well.
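
Roughly, the pattern looks like the sketch below. TinyRNN is a hypothetical, much simplified stand-in for my model (only the hidden activations are carried, and the layer names and sizes are made up for illustration), so treat it as an outline of the idea rather than the actual code:

```python
import torch
import torch.nn as nn

class TinyRNN(nn.Module):
    """Hypothetical simplified model: state is passed in and returned, not stored."""

    def __init__(self, n_inputs, n_hidden, n_outputs, n_internal_steps=2):
        super().__init__()
        self.n_hidden = n_hidden
        self.n_internal_steps = n_internal_steps
        self.input_to_hidden = nn.Linear(n_inputs, n_hidden)
        self.hidden_to_hidden = nn.Linear(n_hidden, n_hidden, bias=False)
        self.hidden_to_output = nn.Linear(n_hidden, n_outputs)

    def forward(self, inputs, activs=None):
        # Plain tensors instead of nn.Parameter, so the graph stays connected.
        if activs is None:
            activs = inputs.new_zeros(inputs.shape[0], self.n_hidden)
        for _ in range(self.n_internal_steps):
            activs = torch.tanh(
                self.input_to_hidden(inputs) + self.hidden_to_hidden(activs))
        outputs = torch.tanh(self.hidden_to_output(activs))
        return outputs, activs

torch.manual_seed(0)
net = TinyRNN(n_inputs=3, n_hidden=5, n_outputs=2)
inputs = torch.rand(4, 3, requires_grad=True)

outputs, activs = net(inputs)
outputs.backward(torch.ones_like(outputs))
print(inputs.grad.shape)  # the input gradient is now populated

# The caller holds the state and passes it to the next call; detaching it here
# (an assumption about the desired behaviour) keeps the old graph from being reused.
next_outputs, next_activs = net(inputs, activs=activs.detach())
```

The caller now owns the cached activations, so nothing inside the module has to be replaced between steps.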