# Accessing Module attributes from torch.autograd.Function.backward()

Hi,

I want to compute the per-sample gradients of a linear layer in order to compute their variance.
The standard implementation of the batched weight gradient computes an inner product over the batch dimension in `grad_output.t() @ input`: [Out, BatchSize] x [BatchSize, In] = [Out, In].
Instead of this inner product, we can compute one outer product per sample (of `grad_output[b]` and `input[b]`), which yields the per-sample gradients as a tensor of shape [BatchSize, Out, In].

To make this computationally cheaper, I want to compute the per-sample gradients (the outer products) first, and then obtain the actual gradient from them as a tensor contraction over the BatchSize dimension.

My problem is that, although I can compute the outer product in the backward pass, I can't store it as an attribute of the weight: the autograd engine seems to do something unexpected, because the `id()` of the weight tensor is not the same in the backward pass as in the forward pass.

If you run the code below (a modified version of the 'Extending PyTorch' tutorial), you will see that Python's `id()` function returns different values for the weight of the linear layer in the forward and in the backward pass.

``````python
import future, sys, os, datetime, argparse
# print(os.path.dirname(sys.executable))
import torch
import numpy as np
import matplotlib
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

matplotlib.rcParams["figure.figsize"] = [10, 10]

import torch, torch.nn
from torch.nn import Module, Parameter
from torch.nn import Linear, Tanh, ReLU
import torch.nn.functional as F

# Short module-level aliases for the tensor classes.
Tensor, FloatTensor = torch.Tensor, torch.FloatTensor

# Print torch tensors and numpy arrays with 4 decimals and without
# scientific notation.
torch.set_printoptions(sci_mode=False, precision=4)
np.set_printoptions(suppress=True, precision=4)

# @property

# Inherit from Function

# Note that both forward and backward are @staticmethods
@staticmethod
# bias is an optional argument
def forward(ctx, input, weight, bias=None):
    """Compute ``input @ weight.T (+ bias)`` and save the inputs for backward.

    Args:
        ctx: autograd context; ``input``, ``weight`` and ``bias`` are stored
            on it with ``save_for_backward``.
        input: 2-D tensor (``Tensor.mm`` needs 2-D operands), expected
            ``[BS, In]``.
        weight: 2-D tensor of shape ``[Out, In]``.
        bias: optional tensor broadcast over the batch rows (``[Out]``).

    Returns:
        Tensor of shape ``[BS, Out]``.
    """
    ctx.save_for_backward(input, weight, bias)
    # Debug output: backward prints the id of its saved `weight` as well, so
    # the two Python object ids can be compared.
    print(f"forward: {id(weight)=}")
    output = input.mm(weight.t())
    if bias is not None:
        # Broadcast the bias over all batch rows, in place on the fresh
        # result of mm.
        output += bias.unsqueeze(0).expand_as(output)
    return output

# This function has only a single output, so it gets only one gradient
# NOTE(review): this paste lost its indentation and several lines. The
# `def backward(ctx, grad_output):` line that should follow the decorator is
# missing (the body below uses `ctx` and `grad_output` without binding them),
# and so are the grad_input / grad_bias computations and the `return`.
@staticmethod
# This is a pattern that is very convenient - at the top of backward
# unpack saved_tensors and initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
# NOTE(review): the "initialize all gradients to None" statement described
# above is not present in this paste.
input, weight, bias = ctx.saved_tensors

# Debug output to compare with the id printed in forward; the post reports
# that the two ids differ.
print(f'backward: {id(weight)=}')
# print([id(tmp) for tmp in [input, weight, bias]])
# print([id(tmp) for tmp in ctx.saved_tensors])
# print([id(tmp) for tmp in ctx.saved_variables])

# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
# NOTE(review): the string below appears to label a grad_input computation
# that is not present in this paste.
'''Stays the same as its the gradient wrt to the input'''
'''
input : 	[BS, In ]
weight :	[Out, In]
'''
# Per-sample weight gradients: a batched outer product, one [Out, In] matrix
# per sample. The factor input.shape[0] (the batch size) presumably rescales
# the gradient of a mean-reduced loss to per-sample scale - TODO confirm.
# Because of that factor, grad_weight_batch.mean(0) equals grad_weight below
# (up to floating-point rounding). In the visible lines this tensor is
# computed but neither stored nor returned.
grad_weight_batch = grad_output.unsqueeze(-1).matmul(input.unsqueeze(-2)) * input.shape[0] # [BS, Out, 1] x [BS, 1, In] = [BS, Out, In]
# Batch-summed weight gradient, as in the standard linear layer.
grad_weight = grad_output.t().mm(input) # [Out, BS] x [Bs, In] = [Out, In]
# NOTE(review): the body of this `if` (presumably the bias gradient) is
# missing from the paste.
if bias is not None and ctx.needs_input_grad[2]:

def __init__(self, input_features, output_features, bias=True):
    """Create the parameters of the custom linear layer.

    Args:
        input_features: size of each input sample; number of columns of
            ``weight``.
        output_features: size of each output sample; number of rows of
            ``weight`` and length of ``bias``.
        bias: if truthy, create a learnable bias; otherwise register
            ``bias`` as ``None``.
    """
    super().__init__()
    self.input_features = input_features
    self.output_features = output_features

    # Assigning an nn.Parameter as an attribute registers it with the module,
    # so it shows up in .parameters() and is converted by e.g. .cuda().
    # nn.Parameters require gradients by default.
    # torch.empty replaces the legacy torch.Tensor(*sizes) constructor; the
    # uninitialized values are overwritten by the initialization below.
    self.weight = Parameter(torch.empty(output_features, input_features))
    if bias:
        self.bias = Parameter(torch.empty(output_features))
    else:
        # Register optional parameters even when absent, so that self.bias
        # always exists (as None).
        self.register_parameter('bias', None)

    # Simple uniform initialization in [-0.1, 0.1). It is done in place under
    # no_grad rather than through the discouraged `.data` attribute.
    with torch.no_grad():
        self.weight.uniform_(-0.1, 0.1)
        if self.bias is not None:
            self.bias.uniform_(-0.1, 0.1)

def forward(self, input):
# See the autograd section for explanation of what happens here.
# NOTE(review): the body of this method is missing from the paste, so as
# shown the `def` has no statements (not valid Python). It presumably calls
# the custom autograd Function on (input, self.weight, self.bias) and returns
# the result - confirm against the original post.

def extra_repr(self):
    """Return the extra text that nn.Module includes in this module's repr.

    You can see it by printing an object of this class.
    """
    return (
        f'input_features={self.input_features}, '
        f'output_features={self.output_features}, '
        f'bias={self.bias is not None}'
    )

# NOTE(review): the two comment lines below are the tail of a longer comment
# whose beginning is missing from the paste; they describe a numerical
# gradient check, but no such check is called in the visible code.
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.

# print(f'linear weight: {id(linear.weight)}')

# NOTE(review): `linear` is never defined in the visible code - the line that
# instantiates the custom layer (with input_features=11, to match the input
# below) appears to be missing, so `linear(input)` would raise NameError.
input = torch.randn(27,11)
output = linear(input)
# output.sum() makes the incoming gradient (grad_output in backward) a tensor
# of ones.
output.sum().backward()

# print(f'linear weight: {id(linear.weight)=}')
# print(f'linear weight: {id(linear.weight.data)=}')

``````

My question is therefore: how can I compute additional quantities in the backward pass (such as the per-sample gradients) and store them conveniently close to the parameters, e.g. as an attribute of the Parameter object?

I got a version running that satisfies my needs.

The trick is to pass the module itself as an extra argument to the autograd Function's `forward` (which stores it as `ctx.module`), so that the backward pass can access the module by reference.

The main issue was that, although I could compute the per-sample gradients during the backward call, there was no reference to the parent module, so there was nowhere to store the values computed in the backward call.

Having that reference allows us to set attributes on the module and its parameters during the backward call.

``````python
import future, sys, os, datetime, argparse
# print(os.path.dirname(sys.executable))
import torch
import numpy as np
import matplotlib
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

matplotlib.rcParams["figure.figsize"] = [10, 10]

import torch, torch.nn
from torch.nn import Module, Parameter
from torch.nn import Linear, Tanh, ReLU
import torch.nn.functional as F

# Short module-level aliases for the tensor classes.
Tensor, FloatTensor = torch.Tensor, torch.FloatTensor

# Print torch tensors and numpy arrays with 4 decimals and without
# scientific notation.
torch.set_printoptions(sci_mode=False, precision=4)
np.set_printoptions(suppress=True, precision=4)

# @property

# Inherit from Function

# Note that both forward and backward are @staticmethods
@staticmethod
# bias is an optional argument; module is the object that backward may
# attach extra results to
def forward(ctx, input, weight, bias=None, module=None):
    """Compute ``input @ weight.T (+ bias)`` and remember the owning module.

    Args:
        ctx: autograd context. ``input``, ``weight`` and ``bias`` are saved
            with ``save_for_backward``; ``module`` is stored as
            ``ctx.module``.
        input: 2-D tensor (``Tensor.mm`` needs 2-D operands), expected
            ``[BS, In]``.
        weight: 2-D tensor of shape ``[Out, In]``.
        bias: optional tensor broadcast over the batch rows (``[Out]``).
        module: optional object kept on ``ctx`` so that backward can attach
            extra results (such as per-sample gradients) to it.

    Returns:
        Tensor of shape ``[BS, Out]``.
    """
    ctx.save_for_backward(input, weight, bias)
    # The module is not a tensor, so it cannot go through save_for_backward.
    # A plain attribute on ctx keeps a reference to the very same Python
    # object for use in backward.
    ctx.module = module

    output = input.mm(weight.t())
    if bias is not None:
        # Broadcast the bias over all batch rows, in place on the fresh
        # result of mm.
        output += bias.unsqueeze(0).expand_as(output)
    return output

# This function has only a single output, so it gets only one gradient
# NOTE(review): this paste lost its indentation and several lines. The
# `def backward(ctx, grad_output):` line that should follow the decorator is
# missing (the body below uses `ctx` and `grad_output` without binding them),
# and so are the grad_input / grad_bias computations and the `return`.
# Per the torch.autograd.Function docs, backward must return one gradient
# per forward input; forward here takes (input, weight, bias, module), so
# the return needs a trailing None for `module` - confirm.
@staticmethod
# This is a pattern that is very convenient - at the top of backward
# unpack saved_tensors and initialize all gradients w.r.t. inputs to
# None. Thanks to the fact that additional trailing Nones are
# ignored, the return statement is simple even when the function has
# optional inputs.
# NOTE(review): the "initialize all gradients to None" statement described
# above is not present in this paste.
input, weight, bias = ctx.saved_tensors

# These needs_input_grad checks are optional and there only to
# improve efficiency. If you want to make your code simpler, you can
# skip them. Returning gradients for inputs that don't require it is
# not an error.
# NOTE(review): the string below appears to label a grad_input computation
# that is not present in this paste.
'''Stays the same as its the gradient wrt to the input'''
'''
input : 	[BS, In ]
weight :	[Out, In]
'''
# Per-sample weight gradients: a batched outer product, one [Out, In] matrix
# per sample. The factor input.shape[0] (the batch size) presumably rescales
# the gradient of a mean-reduced loss to per-sample scale - TODO confirm.
# NOTE(review): forward stores the owning module on ctx.module, but none of
# the visible lines read it; the code that attaches grad_weight_batch to the
# module appears to be missing from the paste.
grad_weight_batch = grad_output.unsqueeze(-1).matmul(input.unsqueeze(-2)) * input.shape[0] # [BS, Out, 1] x [BS, 1, In] = [BS, Out, In]
# NOTE(review): with the next line commented out, no grad_weight is computed
# in the visible code. Because of the input.shape[0] factor above,
# grad_weight_batch.mean(0) reproduces grad_output.t().mm(input) (up to
# floating-point rounding) if the intent is to derive it from the per-sample
# gradients - confirm.
# grad_weight = grad_output.t().mm(input) # [Out, BS] x [Bs, In] = [Out, In]
# (`global_weight` is not defined anywhere in the visible code.)
# print(weight is global_weight)
# NOTE(review): the body of this `if` (presumably the bias gradient) is
# missing from the paste.
if bias is not None and ctx.needs_input_grad[2]:

def __init__(self, input_features, output_features, bias=True):
    """Create the parameters of the custom linear layer.

    Args:
        input_features: size of each input sample; number of columns of
            ``weight``.
        output_features: size of each output sample; number of rows of
            ``weight`` and length of ``bias``.
        bias: if truthy, create a learnable bias; otherwise register
            ``bias`` as ``None``.
    """
    super().__init__()
    self.input_features = input_features
    self.output_features = output_features

    # Assigning an nn.Parameter as an attribute registers it with the module,
    # so it shows up in .parameters() and is converted by e.g. .cuda().
    # nn.Parameters require gradients by default.
    # torch.empty replaces the legacy torch.Tensor(*sizes) constructor; the
    # uninitialized values are overwritten by the initialization below.
    self.weight = Parameter(torch.empty(output_features, input_features))
    if bias:
        self.bias = Parameter(torch.empty(output_features))
    else:
        # Register optional parameters even when absent, so that self.bias
        # always exists (as None).
        self.register_parameter('bias', None)

    # Simple uniform initialization in [-0.1, 0.1). It is done in place under
    # no_grad rather than through the discouraged `.data` attribute.
    with torch.no_grad():
        self.weight.uniform_(-0.1, 0.1)
        if self.bias is not None:
            self.bias.uniform_(-0.1, 0.1)

def forward(self, input):
# See the autograd section for explanation of what happens here.
# Debug output: id of the weight Parameter held by the module, for comparison
# with the id seen inside the autograd Function.
print(f"forward: {id(self.weight)=}")
# NOTE(review): the rest of this method is missing from the paste (there is
# no `return`). Per the post, it presumably passes `self` as the `module`
# argument of the custom autograd Function so that backward can reach the
# module through ctx.module - confirm against the original post.

def extra_repr(self):
    """Return the extra text that nn.Module includes in this module's repr.

    You can see it by printing an object of this class.
    """
    return (
        f'input_features={self.input_features}, '
        f'output_features={self.output_features}, '
        f'bias={self.bias is not None}'
    )

# NOTE(review): the two comment lines below are the tail of a longer comment
# whose beginning is missing from the paste; they describe a numerical
# gradient check, but no such check is called in the visible code.
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.

# NOTE(review): `linear` is never defined in the visible code - the line that
# instantiates the custom layer (with input_features=11, to match the input
# below) appears to be missing, so this print would raise NameError.
print(f'linear weight: {id(linear.weight)}')

input = torch.randn(27,11)
output = linear(input)
# output.sum() makes the incoming gradient (grad_output in backward) a tensor
# of ones.
output.sum().backward()