Recording a history of activation gradients

Is there an accepted way to store a history of gradients of activations in a network?

The documentation for modifying the backward operation of a linear layer implies that I cannot store persistent state outside of the ctx-based communication channel between forward() and backward().

https://pytorch.org/docs/stable/notes/extending.html

import torch
from torch.autograd import Function

# Inherit from Function
class LinearFunction(Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias
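
For reference, the custom Function above is used through its apply method rather than by instantiating it; a quick usage sketch (shapes purely illustrative):

import torch

weight = torch.randn(3, 4, requires_grad=True)
bias = torch.randn(3, requires_grad=True)
x = torch.randn(2, 4, requires_grad=True)

y = LinearFunction.apply(x, weight, bias)  # runs forward()
y.sum().backward()                         # runs backward()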

Ideally, if these were not @staticmethods, it would be possible to store grad_input and grad_output directly in a specific instance of the Function class, with its own memory store. Something like:

class LinearFunction(Function):
    def __init__(self):
        self.grad_history = []

    def forward(self, ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    def backward(self, ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        self.grad_history.append(grad_output)
        return grad_input, grad_weight, grad_bias

Try hooking the tensors of interest with:

def _on_gradient_of_watched(name, gr):
    gr = gr.detach().cpu().clone()
    ...

def watch_tensor(name: str, te) -> None:
    if te.requires_grad:
        te.register_hook(lambda gr: _on_gradient_of_watched(name, gr))
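
For example (a hypothetical usage sketch; the layer, the tensor names, and the elided body of _on_gradient_of_watched above are placeholders): call watch_tensor on an intermediate activation during the forward pass, and the hook fires when the gradient with respect to that activation is computed in the backward pass.

import torch
import torch.nn as nn

linear = nn.Linear(4, 3)
x = torch.randn(2, 4, requires_grad=True)

h = linear(x)          # intermediate activation of interest
watch_tensor("h", h)   # hook fires when d(loss)/d(h) is computed
h.relu().sum().backward()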

This is a good idea; I had not considered hooks. As a next step, I would also like to modify the actual autograd backward() based on this recorded history. Ideally, that logic would be encapsulated within the particular Function instance.

You can just add auxiliary inputs and/or outputs to Function, adding data paths without gradient flow.
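
A minimal sketch of that idea, assuming a plain Python list is passed in as the auxiliary argument (the RecordingLinearFunction name and the grad_store parameter are illustrative, not from the docs): non-tensor inputs participate in forward() but simply get None returned as their "gradient" in backward().

import torch
from torch.autograd import Function

class RecordingLinearFunction(Function):

    @staticmethod
    def forward(ctx, input, weight, bias, grad_store):
        ctx.save_for_backward(input, weight, bias)
        ctx.grad_store = grad_store  # plain Python object; no gradient flows through it
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        ctx.grad_store.append(grad_output.detach().cpu().clone())  # record the history
        grad_input = grad_output.mm(weight)
        grad_weight = grad_output.t().mm(input)
        grad_bias = grad_output.sum(0) if bias is not None else None
        # None for the auxiliary grad_store input: it carries no gradient
        return grad_input, grad_weight, grad_bias, None

history = []
weight = torch.randn(3, 4, requires_grad=True)
bias = torch.randn(3, requires_grad=True)
x = torch.randn(2, 4, requires_grad=True)
y = RecordingLinearFunction.apply(x, weight, bias, history)
y.sum().backward()  # history now holds one recorded grad_output

A module wrapping this Function could own the history list, which keeps the recorded gradients scoped to a particular layer instance.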

Also, is there a reliable way to find out the order of the grad_input tuple? (Module — PyTorch 1.9.0 documentation) I believe the grad_input tuple for a linear layer is (dy/d_bias, dy/d_input, dy/d_weight), but for other layers I would like a reliable way to map this.

I’m getting conflicting information from multiple third-party sources about the arguments that the hook function receives.
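
One way I can think of to sanity-check the ordering empirically is to print the shapes seen by a hook; a rough sketch using the newer register_full_backward_hook, whose grad_input and grad_output are documented to line up with the module's forward inputs and outputs:

import torch
import torch.nn as nn

def shape_probe(module, grad_input, grad_output):
    # Match each entry to a tensor by shape to figure out what it corresponds to
    print("grad_input shapes: ", [None if g is None else tuple(g.shape) for g in grad_input])
    print("grad_output shapes:", [None if g is None else tuple(g.shape) for g in grad_output])

layer = nn.Linear(4, 3)
layer.register_full_backward_hook(shape_probe)

x = torch.randn(2, 4, requires_grad=True)
layer(x).sum().backward()
# expected: grad_input -> [(2, 4)] (w.r.t. the forward input), grad_output -> [(2, 3)]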

Sorry, I haven’t used that API, and its implementation looks convoluted.

Is there a clean way to get the corresponding input activation from the hook_fn()? Right now, I am collecting the input by wrapping the forward() function, which seems error-prone.

Normally, if I were to rewrite the entire layer, I could store the corresponding input via ctx.saved_tensors:

def backward(self, ctx, grad_output):
    input, weight, bias = ctx.saved_tensors

“input”? No, because it can potentially be released (garbage collected) before the backward pass; it is not auto-stored anywhere in the general case.

Ah, that makes sense from a memory-saving point of view for inference. But the input is needed for backpropagation to update the weight matrices. So it has to be kept somewhere during training? Presumably it would be in the module argument of:

def hook_fn(module, grad_input, grad_output)

When some Function needs to keep an input tensor, it is stored locally in saved_tensors; this is the same for built-in operators, which just use the C++ version of Function. There is an obvious disconnect between Module and Function, and some gradients are computed without storing function inputs.
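
As a quick way to see that disconnect (a small illustrative check, not from the original discussion): the grad_fn of a module’s output is a C++ autograd node, and it holds no reference back to the Module that produced it.

import torch
import torch.nn as nn

layer = nn.Linear(4, 3)
y = layer(torch.randn(2, 4, requires_grad=True))
print(type(y.grad_fn).__name__)  # a C++ node such as AddmmBackward0, not the nn.Linear module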

For module-level hooks, it would be much easier to exploit Python’s dynamism and additionally hook Module.__call__ to record inputs. In fact, I think Module.register_forward_pre_hook exists for this use case (again, I haven’t used it in practice).
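
A rough sketch of that approach (untested, and the record_input/record_output_grad names are illustrative): a forward pre-hook records the input activation on the way in, and a full backward hook records the gradient of the output on the way back, so the two can be paired per module.

import torch
import torch.nn as nn

input_history = {}  # module -> list of recorded input activations
grad_history = {}   # module -> list of recorded output gradients

def record_input(module, inputs):
    # register_forward_pre_hook passes the positional inputs as a tuple
    input_history.setdefault(module, []).append(inputs[0].detach().cpu().clone())

def record_output_grad(module, grad_input, grad_output):
    grad_history.setdefault(module, []).append(grad_output[0].detach().cpu().clone())

layer = nn.Linear(4, 3)
layer.register_forward_pre_hook(record_input)
layer.register_full_backward_hook(record_output_grad)

x = torch.randn(2, 4, requires_grad=True)
layer(x).sum().backward()
# input_history[layer][0] is the activation that fed the layer;
# grad_history[layer][0] is d(loss)/d(output) for the same call.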