Hi,
I want to compute the per-sample gradients of a linear layer in order to compute their variance.
The standard implementation for batched gradients computes the inner product over the batch dimension in `grad_output.t() @ input`: [Out, BatchSize] x [BatchSize, In] = [Out, In].
Instead of computing the inner product, we can compute the outer product for each sample and obtain the per-sample gradients as a tensor of shape [BatchSize, Out, In].
To make the whole thing computationally faster, I first want to compute the outer product for the per-sample gradients and then provide the actual gradients as a tensor contraction over the BatchSize dimension.
My problem is that I can compute the outer product in the backward pass, but I can't store it as an attribute of the weight: the autograd engine apparently hands the backward pass a different Python object, because the `id(tensor)` values in the forward pass and in the backward pass are not the same.
If you run the code below (a modified version of the 'Extending PyTorch' tutorial), you will see that Python's `id()` function yields different ids for the weight of the linear layer in the forward and the backward pass.
# Module preamble: imports, plotting defaults, tensor aliases and print options.
# NOTE(review): `future` is a third-party package, not `__future__`; it does
# not appear to be used in the code shown here -- confirm before removing.
import future, sys, os, datetime, argparse
# print(os.path.dirname(sys.executable))
import torch
import numpy as np
import matplotlib
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
# Default figure size for every matplotlib figure created afterwards.
matplotlib.rcParams["figure.figsize"] = [10, 10]
# NOTE(review): `torch` is imported a second time here; only `torch.nn` is new.
import torch, torch.nn
from torch.autograd.function import Function
from torch.nn import Module, Parameter
from torch.nn import Linear, Tanh, ReLU
import torch.nn.functional as F
# Short aliases for the tensor classes.
Tensor = torch.Tensor
FloatTensor = torch.FloatTensor
# Print tensors and arrays with 4 digits of precision; `sci_mode=False`
# turns off scientific notation for tensors.
torch.set_printoptions(precision=4, sci_mode=False)
np.set_printoptions(precision=4, suppress=True)
class GradBatch_Parameter(torch.nn.Parameter):
    """Parameter subclass meant as a carrier for per-sample gradients.

    The class currently adds no behaviour of its own. The disabled sketch at
    the bottom shows the intended idea: derive ``grad`` from a ``grad_batch``
    attribute by averaging over the batch dimension.
    """

    def __init__(self, data, requires_grad=True):
        # NOTE(review): neither ``data`` nor ``requires_grad`` is forwarded
        # to the base initialiser; presumably ``torch.nn.Parameter.__new__``
        # has already consumed them -- confirm.
        super().__init__()

    # Disabled sketch of a ``grad`` property backed by ``grad_batch``:
    # @property
    # def grad(self):
    #     assert hasattr(self, 'grad_batch')
    #     return self.grad_batch.mean(dim=0)
# Inherit from Function
class GradBatch_LinearFunction(Function):
    """Linear map ``input @ weight.T (+ bias)`` that also records per-sample gradients.

    The gradients returned to autograd are the usual batch-reduced ones.
    In addition, ``backward`` stores the per-sample gradients as a
    ``grad_batch`` attribute on the ``weight`` and ``bias`` objects that were
    passed to ``forward``:

    * ``weight.grad_batch``: ``[BS, Out, In]``
    * ``bias.grad_batch``:   ``[BS, Out]``

    Both are scaled by the batch size, so ``grad_batch.mean(dim=0)`` equals
    the gradient returned to autograd.
    NOTE(review): that scaling presumably targets a loss averaged over the
    batch -- confirm against the training loss.

    ``grad_batch`` is overwritten on every backward pass through this
    function; it is not accumulated like ``.grad``.
    """

    # Note that both forward and backward are @staticmethods
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        """Compute ``input @ weight.T`` and add ``bias`` when it is given.

        Args:
            ctx: autograd context used to pass state to ``backward``.
            input: tensor of shape ``[BS, In]``.
            weight: tensor of shape ``[Out, In]``.
            bias: optional tensor of shape ``[Out]``.

        Returns:
            Tensor of shape ``[BS, Out]``.
        """
        ctx.save_for_backward(input, weight, bias)
        # The tensors unpacked from ``ctx.saved_tensors`` in backward can be
        # different Python objects from the ones passed in here (their
        # ``id()`` differs), so attributes set on them would never reach the
        # caller's parameters. Keep references to the caller's own objects.
        ctx.param_refs = (weight, bias)
        print(f"forward: {id(weight)=}")
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients w.r.t. ``input``, ``weight`` and ``bias``.

        Args:
            ctx: autograd context filled in by ``forward``.
            grad_output: gradient w.r.t. the output, shape ``[BS, Out]``.

        Returns:
            Tuple ``(grad_input, grad_weight, grad_bias)``; an entry is
            ``None`` when the corresponding input needs no gradient.
        """
        input, weight, bias = ctx.saved_tensors
        weight_ref, bias_ref = ctx.param_refs
        grad_input = grad_weight = grad_bias = None
        print(f'backward: {id(weight)=}')
        batch_size = input.shape[0]

        # The needs_input_grad checks are optional and only skip work.
        if ctx.needs_input_grad[0]:
            # Gradient w.r.t. the input is unaffected by the per-sample logic.
            grad_input = grad_output.mm(weight)

        if ctx.needs_input_grad[1]:
            # Batch-reduced gradient: [Out, BS] x [BS, In] = [Out, In]
            grad_weight = grad_output.t().mm(input)
            # Per-sample outer products: [BS, Out, 1] x [BS, 1, In] = [BS, Out, In]
            per_sample = grad_output.unsqueeze(-1).matmul(input.unsqueeze(-2))
            # Detach so the stored tensor never keeps a graph alive.
            weight_ref.grad_batch = (per_sample * batch_size).detach()

        if bias is not None and ctx.needs_input_grad[2]:
            # Autograd must receive the batch-reduced [Out] gradient; the
            # per-sample [BS, Out] tensor is only stored on the parameter.
            grad_bias = grad_output.sum(0)
            bias_ref.grad_batch = (grad_output * batch_size).detach()

        return grad_input, grad_weight, grad_bias
class GradBatch_LinearModule(torch.nn.Module):
    """Linear layer whose forward pass goes through ``GradBatch_LinearFunction``.

    Args:
        input_features: size of each input sample.
        output_features: size of each output sample.
        bias: when false, the ``bias`` parameter is registered as ``None``.
    """

    def __init__(self, input_features, output_features, bias=True):
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features

        # Assigning a Parameter as an attribute registers it with the module,
        # so it shows up in .parameters().
        self.weight = Parameter(torch.Tensor(output_features, input_features))
        if not bias:
            # Optional parameters are still registered, just with value None.
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.Tensor(output_features))

        # Simple uniform initialisation; the weight is filled before the bias.
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        """Apply the custom autograd function to ``input``."""
        return GradBatch_LinearFunction.apply(input, self.weight, self.bias)

    def extra_repr(self):
        """Return the summary shown when the module is printed."""
        template = 'input_features={}, output_features={}, bias={}'
        has_bias = self.bias is not None
        return template.format(self.input_features, self.output_features, has_bias)
# Functional form of the custom op; the name is rebound to a module below.
linear = GradBatch_LinearFunction.apply

from torch.autograd import gradcheck

# gradcheck takes a tuple of tensors and compares the analytical gradients
# with numerical approximations; it returns True when they all agree.
# The check itself is currently disabled.
input = (
    torch.randn(20, 20, dtype=torch.double, requires_grad=True),
    torch.randn(30, 20, dtype=torch.double, requires_grad=True),
)
# test = gradcheck(GradBatch_LinearFunction.apply, input, eps=1e-6, atol=1e-4)

# Run one forward/backward pass through the module: 27 samples, 11 -> 12 features.
linear = GradBatch_LinearModule(input_features=11, output_features=12, bias=True)
# print(f'linear weight: {id(linear.weight)}')
input = torch.randn(27, 11)
output = linear(input)
output.sum().backward()

# Disabled diagnostics for comparing object ids and gradient shapes:
# print(linear.weight.grad.shape)
# print(f'linear weight: {id(linear.weight)=}')
# print(f'linear weight: {id(linear.weight.data)=}')
# print(linear.weight.grad.shape)
My question is therefore: how can I compute additional quantities in the backward pass (such as the per-sample gradients) and store them conveniently close to the parameters, e.g. as an attribute of the Parameter object?