Accessing Module attributes from torch.autograd.Function.backward()

Hi,

I want to compute the per sample gradients in a linear layer in order to compute their variance.
The standard implementation of the batched gradient contracts over the batch dimension via grad_output.t() @ input : [ Out, BatchSize ] x [ BatchSize, In ] = [ Out, In ].
Instead of contracting right away, we can compute the outer product per sample and obtain the per sample gradients as a tensor of shape [ BatchSize, Out, In ].

To keep this computationally cheap, I want to first compute the outer products for the per sample gradients and then provide the actual gradient as a contraction of that tensor over the BatchSize dimension.
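
For concreteness, a minimal sketch of that relationship (the shapes and tensor names here are just for illustration):

import torch

BS, In, Out = 27, 11, 12
input = torch.randn(BS, In)
grad_output = torch.randn(BS, Out)

# usual batched gradient: contract over the batch dimension
grad_weight = grad_output.t() @ input                                # [Out, In]

# per sample gradients: one outer product per sample
grad_weight_batch = torch.einsum('bo,bi->boi', grad_output, input)   # [BS, Out, In]

# summing the per sample gradients over the batch recovers the usual gradient
print(torch.allclose(grad_weight_batch.sum(dim=0), grad_weight))     # True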

My problem is that I can compute the outer product in the backward pass, but I can't store it as an attribute of the weight, because the autograd engine does some funky stuff: id(weight) is not the same in the backward pass as in the forward pass.

If you run the code below (a modified version of the 'Extending PyTorch' tutorial), you will see that Python's id() function yields different ids for the weight of the linear layer in the forward and the backward pass.

import future, sys, os, datetime, argparse
# print(os.path.dirname(sys.executable))
import torch
import numpy as np
import matplotlib
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

matplotlib.rcParams["figure.figsize"] = [10, 10]

import torch, torch.nn
from torch.autograd.function import Function
from torch.nn import Module, Parameter
from torch.nn import Linear, Tanh, ReLU
import torch.nn.functional as F

Tensor = torch.Tensor
FloatTensor = torch.FloatTensor

torch.set_printoptions(precision=4, sci_mode=False)
np.set_printoptions(precision=4, suppress=True)


class GradBatch_Parameter(torch.nn.Parameter):

	def __init__(self, data, requires_grad=True):
		super(GradBatch_Parameter, self).__init__()

	# @property
	# def grad(self):
	# 	assert hasattr(self, 'grad_batch')
	# 	return self.grad_batch.mean(dim=0)

# Inherit from Function
class GradBatch_LinearFunction(Function):

	# Note that both forward and backward are @staticmethods
	@staticmethod
	# bias is an optional argument
	def forward(ctx, input, weight, bias=None):
		ctx.save_for_backward(input, weight, bias)
		print(f"forward: {id(weight)=}")
		output = input.mm(weight.t())
		if bias is not None:
			output += bias.unsqueeze(0).expand_as(output)
		return output

	# This function has only a single output, so it gets only one gradient
	@staticmethod
	def backward(ctx, grad_output):
		# This is a pattern that is very convenient - at the top of backward
		# unpack saved_tensors and initialize all gradients w.r.t. inputs to
		# None. Thanks to the fact that additional trailing Nones are
		# ignored, the return statement is simple even when the function has
		# optional inputs.
		input, weight, bias = ctx.saved_tensors
		grad_input = grad_weight = grad_bias = None

		print(f'backward: {id(weight)=}')
		# print([id(tmp) for tmp in [input, weight, bias]])
		# print([id(tmp) for tmp in ctx.saved_tensors])
		# print([id(tmp) for tmp in ctx.saved_variables])

		# These needs_input_grad checks are optional and there only to
		# improve efficiency. If you want to make your code simpler, you can
		# skip them. Returning gradients for inputs that don't require it is
		# not an error.
		if ctx.needs_input_grad[0]:
			'''Stays the same, as it's the gradient w.r.t. the input'''
			grad_input = grad_output.mm(weight)
		if ctx.needs_input_grad[1]:
			'''
			grad_output : 	[BS, Out]
			input : 	[BS, In ]
			weight :	[Out, In]
			'''
			grad_weight_batch = grad_output.unsqueeze(-1).matmul(input.unsqueeze(-2)) * input.shape[0] # [BS, Out, 1] x [BS, 1, In] = [BS, Out, In]
			grad_weight = grad_output.t().mm(input) # [Out, BS] x [BS, In] = [Out, In]
			assert torch.allclose(grad_weight_batch.mean(dim=0), grad_weight)
			# weight.grad_batch = grad_weight_batch
			# grad_weight.batch = grad_weight_batch
		if bias is not None and ctx.needs_input_grad[2]:
			grad_bias = grad_output.sum(0)
			grad_bias_batch = grad_output * input.shape[0]
			assert torch.allclose(grad_bias_batch.mean(dim=0), grad_bias)
			# bias.grad_batch = grad_bias_batch
			grad_bias = grad_bias_batch.mean(dim=0) # returning the [BS, Out] batch tensor here would fail autograd's shape check against the [Out] bias

		return grad_input, grad_weight, grad_bias


class GradBatch_LinearModule(torch.nn.Module):
	def __init__(self, input_features, output_features, bias=True):
		super().__init__()
		self.input_features = input_features
		self.output_features = output_features

		# nn.Parameter is a special kind of Tensor, that will get
		# automatically registered as Module's parameter once it's assigned
		# as an attribute. Parameters and buffers need to be registered, or
		# they won't appear in .parameters() (doesn't apply to buffers), and
		# won't be converted when e.g. .cuda() is called. You can use
		# .register_buffer() to register buffers.
		# nn.Parameters require gradients by default.
		self.weight = Parameter(torch.Tensor(output_features, input_features))
		if bias:
			self.bias = Parameter(torch.Tensor(output_features))
		else:
			# You should always register all possible parameters, but the
			# optional ones can be None if you want.
			self.register_parameter('bias', None)

		# Not a very smart way to initialize weights
		self.weight.data.uniform_(-0.1, 0.1)
		if self.bias is not None:
			self.bias.data.uniform_(-0.1, 0.1)

	def forward(self, input):
		# See the autograd section for explanation of what happens here.
		return GradBatch_LinearFunction.apply(input, self.weight, self.bias)

	def extra_repr(self):
		# (Optional)Set the extra information about this module. You can test
		# it by printing an object of this class.
		return 'input_features={}, output_features={}, bias={}'.format(
			self.input_features, self.output_features, self.bias is not None
		)


linear = GradBatch_LinearFunction.apply

from torch.autograd import gradcheck

# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
input = (torch.randn(20, 20, dtype=torch.double, requires_grad=True), torch.randn(30, 20, dtype=torch.double, requires_grad=True))
# test = gradcheck(GradBatch_LinearFunction.apply, input, eps=1e-6, atol=1e-4)

linear = GradBatch_LinearModule(11, 12, True)
# print(f'linear weight: {id(linear.weight)}')

input = torch.randn(27,11)
output = linear(input)
output.sum().backward()

# print(linear.weight.grad.shape)
# print(f'linear weight: {id(linear.weight)=}')
# print(f'linear weight: {id(linear.weight.data)=}')
# print(linear.weight.grad.shape)


My question is therefore: how can I compute additional quantities in the backward pass (such as the per sample gradients) and store them conveniently close to the parameters, e.g. as an attribute of the Parameter object?
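
In other words, after calling backward() I would like to be able to do something like this, where grad_batch is the hypothetical attribute that I want to populate from inside backward() (using the module defined above):

linear = GradBatch_LinearModule(11, 12, True)
output = linear(torch.randn(27, 11))
output.sum().backward()

# desired: per sample gradients stored right next to the parameter
print(linear.weight.grad_batch.shape)  # torch.Size([27, 12, 11]), i.e. [BatchSize, Out, In]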

I got a version running that satisfies my needs.

The main issue was that, although I could compute the per sample gradients during the backward call, there was no reference to the parent module, so there was no place to store the computed values.

The trick is to pass the module itself into the forward function so that it is available by reference in backward(). That allows us to set attributes on the module and its parameters during the backward call.

import future, sys, os, datetime, argparse
# print(os.path.dirname(sys.executable))
import torch
import numpy as np
import matplotlib
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

matplotlib.rcParams["figure.figsize"] = [10, 10]

import torch, torch.nn
from torch.autograd.function import Function
from torch.nn import Module, Parameter
from torch.nn import Linear, Tanh, ReLU
import torch.nn.functional as F

Tensor = torch.Tensor
FloatTensor = torch.FloatTensor

torch.set_printoptions(precision=4, sci_mode=False)
np.set_printoptions(precision=4, suppress=True)


class GradBatch_Parameter(torch.nn.Parameter):

	def __init__(self, data, requires_grad=True):
		super(GradBatch_Parameter, self).__init__()

	# @property
	# def grad(self):
	# 	assert hasattr(self, 'grad_batch')
	# 	return self.grad_batch.mean(dim=0)

# Inherit from Function
class GradBatch_LinearFunction(Function):

	# Note that both forward and backward are @staticmethods
	@staticmethod
	# bias is an optional argument
	def forward(ctx, input, weight, bias=None, module=None):
		ctx.save_for_backward(input, weight, bias)
		ctx.module = module
		# print(f"forward func: {id(weight)=}")
		# print(f"forward func: {id(ctx.to_save[1])=}")

		output = input.mm(weight.t())
		if bias is not None:
			output += bias.unsqueeze(0).expand_as(output)
		return output

	# This function has only a single output, so it gets only one gradient
	@staticmethod
	def backward(ctx, grad_output):
		# This is a pattern that is very convenient - at the top of backward
		# unpack saved_tensors and initialize all gradients w.r.t. inputs to
		# None. Thanks to the fact that additional trailing Nones are
		# ignored, the return statement is simple even when the function has
		# optional inputs.
		input, weight, bias = ctx.saved_tensors
		grad_input = grad_weight = grad_bias = None

		# These needs_input_grad checks are optional and there only to
		# improve efficiency. If you want to make your code simpler, you can
		# skip them. Returning gradients for inputs that don't require it is
		# not an error.
		if ctx.needs_input_grad[0]:
			'''Stays the same, as it's the gradient w.r.t. the input'''
			grad_input = grad_output.mm(weight)
		if ctx.needs_input_grad[1]:
			'''
			grad_output : 	[BS, Out]
			input : 	[BS, In ]
			weight :	[Out, In]
			'''
			grad_weight_batch = grad_output.unsqueeze(-1).matmul(input.unsqueeze(-2)) * input.shape[0] # [BS, Out, 1] x [BS, 1, In] = [BS, Out, In]
			# grad_weight = grad_output.t().mm(input) # [Out, BS] x [Bs, In] = [Out, In]
			# assert torch.allclose(grad_weight_batch.mean(dim=0), grad_weight)
			grad_weight = grad_weight_batch.mean(dim=0)
			# weight.grad_batch = grad_weight_batch does NOT stick: the weight from
			# ctx.saved_tensors is a different Python object than the module's Parameter.
			# Store the per sample gradients via the module reference instead.
			ctx.module.weight.grad_batch = grad_weight_batch
		if bias is not None and ctx.needs_input_grad[2]:
			# grad_bias = grad_output.sum(0)
			grad_bias_batch = grad_output * input.shape[0]
			# assert torch.allclose(grad_bias_batch.mean(dim=0), grad_bias)
			grad_bias = grad_bias_batch.mean(dim=0)
			# same trick for the bias: attach the batch of gradients via the module reference
			ctx.module.bias.grad_batch = grad_bias_batch

		return grad_input, grad_weight, grad_bias, None

class GradBatch_LinearModule(torch.nn.Module):
	def __init__(self, input_features, output_features, bias=True):
		super().__init__()
		self.input_features = input_features
		self.output_features = output_features

		# nn.Parameter is a special kind of Tensor, that will get
		# automatically registered as Module's parameter once it's assigned
		# as an attribute. Parameters and buffers need to be registered, or
		# they won't appear in .parameters() (doesn't apply to buffers), and
		# won't be converted when e.g. .cuda() is called. You can use
		# .register_buffer() to register buffers.
		# nn.Parameters require gradients by default.
		self.weight = Parameter(torch.Tensor(output_features, input_features))
		if bias:
			self.bias = Parameter(torch.Tensor(output_features))
		else:
			# You should always register all possible parameters, but the
			# optional ones can be None if you want.
			self.register_parameter('bias', None)

		# Not a very smart way to initialize weights
		self.weight.data.uniform_(-0.1, 0.1)
		if self.bias is not None:
			self.bias.data.uniform_(-0.1, 0.1)

	def forward(self, input):
		# See the autograd section for explanation of what happens here.
		print(f"forward: {id(self.weight)=}")
		return GradBatch_LinearFunction.apply(input, self.weight, self.bias, self)

	def extra_repr(self):
		# (Optional)Set the extra information about this module. You can test
		# it by printing an object of this class.
		return 'input_features={}, output_features={}, bias={}'.format(
			self.input_features, self.output_features, self.bias is not None
		)




linear = GradBatch_LinearFunction.apply

from torch.autograd import gradcheck

# gradcheck takes a tuple of tensors as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
# input = (torch.randn(20, 20, dtype=torch.double, requires_grad=True), torch.randn(30, 20, dtype=torch.double, requires_grad=True))
# test = gradcheck(GradBatch_LinearFunction.apply, input, eps=1e-6, atol=1e-4)
# test = torch.autograd.gradgradcheck(GradBatch_LinearFunction.apply, input, eps=1e-6, atol=1e-4)

linear = GradBatch_LinearModule(11, 12, True)
print(f'linear weight: {id(linear.weight)}')

input = torch.randn(27,11)
output = linear(input)
output.sum().backward()



# print(linear.weight.grad.shape)
# print(f'linear weight: {id(linear.weight)=}')
# print(f'linear weight: {id(linear.weight.data)=}')
print(linear.weight.grad_batch.shape)
print(torch.allclose(linear.weight.grad_batch.mean(dim=0), linear.weight.grad))
print(torch.allclose(linear.bias.grad_batch.mean(dim=0), linear.bias.grad))
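
With the grad_batch attributes set in backward(), the original goal, the variance of the gradients over the batch, is just a reduction over the batch dimension (a sketch continuing the script above; note that grad_batch is scaled by the batch size, as in the backward pass):

# per entry variance of the per sample gradients
print(linear.weight.grad_batch.var(dim=0).shape)  # [Out, In] = torch.Size([12, 11])
print(linear.bias.grad_batch.var(dim=0).shape)    # [Out] = torch.Size([12])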