Hello,
First of all, thank you for reading my post !
In short, I am having trouble to obtain the same gradients when using .register_hook in the forward pass of a class, vs. using register_full_backward_hook outside of a class.
I recently started using hooks to obtain attention/ gradients related to Transformers. I am working on Vision-Language transformers (image & text modalities), such as the ALBEF model in my case (paper linked here but it is not necessary at all to answer my question: (https://proceedings.neurips.cc/paper_files/paper/2021/file/505259756244493872b7709a8a01b536-Paper.pdf)
I have to compare myself with a method, which obtained gradients in an Attention class in the forward function through attn.register_hook:
Here is the Attention class in question :
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.attn_gradients = None
self.attention_map = None
def save_attn_gradients(self, attn_gradients):
self.attn_gradients = attn_gradients
def get_attn_gradients(self):
return self.attn_gradients
def save_attention_map(self, attention_map):
self.attention_map = attention_map
def get_attention_map(self):
return self.attention_map
def forward(self, x, register_hook=False):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
if register_hook:
self.save_attention_map(attn)
attn.register_hook(self.save_attn_gradients)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
As long as we can set the register_hook=True parameter in the forward pass, then I can obtain attention/gradients as I want on the Attention Class (with ‘self.save_attention_map(attn)’ and ‘attn.register_hook(self.save_attn_gradients)’ )
However, this parameter is not available in higher class levels (ALBEF class for example), and more importantly, the ALBEF class in my case is imported from a library, preventing me from adapting the code to obtain attention/gradients through the register_hook=True parameter in higher level classes.
Thus, I started to try to obtain attention and gradients outside of a class, using .register_forward_hook and .register_full_backward_hook (note that the difference with .register_backward_hook is not totally clear to me).
I am now trying to obtain the same results for attention and gradients retrieved through 1) .register_hook in the forward pass vs 2). .register_forward_hook and .register_full_backward_hook outside of the class.
I managed to retrieve the same attentions, but not gradients. I investigated and noted that when I am using only one attention block, gradients are the same when using .register_hook and .register_full_backward_hook, but as soon as I have more than one Attention block, it seems that gradients obtained through both means differ.
Here is a code that you can run, which provides diverging gradients. If you apply it on one block only (by setting depth=1 in the main function), the gradients obtained will be the same. Otherwise they will differ (depth is set to 2 by default in the code)
If you can help me understanding why I would very thankful, because I am missing something here regarding to hooks in Pytorch, despite looking at it a few times…
(I am running on Python 3.8.8 and Pytorch ‘1.12.1+cu102’, even though I doubt it will affect the results):
"""Test on gradients retrieved through 1) forward + tensor.register_hook and 2) hooks on module outside of class."""
import torch
import torch.nn as nn
# Attention Class (used in Block class)
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.attn_gradients = None
self.attention_map = None
def save_attn_gradients(self, attn_gradients):
self.attn_gradients = attn_gradients
def get_attn_gradients(self):
return self.attn_gradients
def save_attention_map(self, attention_map):
self.attention_map = attention_map
def get_attention_map(self):
return self.attention_map
def forward(self, x, register_hook=False):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
if register_hook:
self.save_attention_map(attn)
attn.register_hook(self.save_attn_gradients)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
# MLP class (used in Block class)
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
# Block class (used in Intermediate class)
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x, register_hook=False):
x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
# 'Intermediate' class to test if gradients obtained differ when using several blocks vs one (depending on the depth parameter!)
class Intermediate(nn.Module):
def __init__(self, dim, num_heads, depth=2):
super().__init__()
self.blocks = nn.ModuleList([
Block(dim=dim, num_heads=num_heads)
for i in range(depth)])
# self.norm = nn.LayerNorm()
def forward(self, x, register_hook=False):
for block in self.blocks:
x = block(x, register_hook=register_hook)
# x = self.norm(x)
return x
def are_tensors_equal(list1, list2):
"""Function testing if tensors from list of tensors are equal or not."""
if len(list1) != len(list2):
print("First Length: False")
else:
print("First Length; ok")
for tensor1, tensor2 in zip(list1, list2):
if not torch.equal(tensor1, tensor2):
print("Equal: False")
else:
print("Equal: True")
def __main__():
# Input Parameters
batch_size = 2
seq_length = 4
input_dim = 768
# Blk param
num_heads = 8
# Attn & Blk input
input_tensor = torch.randn(batch_size, seq_length, input_dim)
# Class instantiation
depth = 2 # !!! Change it to "1" and gradients will be the same.
inter = Intermediate(dim=input_dim, num_heads=num_heads, depth=depth)
# --------------------------------------
# --------------------------------------
# METHOD 2 : Use hooks outside of class
hook_list = []
grad_input_list = []
grad_output_list = []
att_list = []
def grad_hook_fn(module, grad_input, grad_output):
grad_input_list.append(grad_input[0].detach())
grad_output_list.append(grad_output[0].detach())
def att_hook_fn(module, att_input, att_output):
att_list.append(att_output.detach())
for i, blk in enumerate(inter.blocks):
hook_list.append(blk.attn.attn_drop.register_forward_hook(att_hook_fn))
hook_list.append(blk.attn.attn_drop.register_full_backward_hook(grad_hook_fn))
# --------------------------------------
# --------------------------------------
# Forward
output = inter.forward(input_tensor, register_hook=True)
# Loss & Backward
inter.zero_grad()
loss = torch.mean(output) # loss simulation
loss.backward(retain_graph=True)
# --------------------------------------
# --------------------------------------
# METHOD 1 : retrieve gradients and attention saved from using forward(register_hook=True)
atts_forward_list = []
grads_forward_list = []
for i, blk in enumerate(inter.blocks):
atts_forward_list.append(blk.attn.get_attention_map())
grads_forward_list.append(blk.attn.get_attn_gradients())
print("att:", atts_forward_list, "\n")
print("grads:", grads_forward_list)
# --------------------------------------
# --------------------------------------
print("Checking if both methods give the same results for Intermediate class with depth = ", depth, " !")
print("--- Equal: Att ---")
are_tensors_equal(atts_forward_list, att_list)
print("--- Equal: Forward (meth.1) vs Grad input (meth. 2) ---")
are_tensors_equal(grads_forward_list, grad_input_list)
print("--- Forward (meth. 1) vs Grad output (meth. 2) ---")
are_tensors_equal(grads_forward_list, grad_output_list)
# --- Attention ---
# print("atts_forward_list: ", atts_forward_list)
# print("att_list: ", att_list)
# --- Gradients ---
# print("grads_forward_list: ", grads_forward_list)
# print("grad_input_list: ", grad_input_list)
# print("grad_output_list: ", grad_output_list)
__main__()
Thanks for reading all of this !