Diverging gradients obtained through 1) tensor.register_hook inside a class vs. 2) module.register_full_backward_hook outside of a class

Hello,

First of all, thank you for reading my post!

In short, I am having trouble obtaining the same gradients when using .register_hook in the forward pass of a class vs. using .register_full_backward_hook outside of the class.


I recently started using hooks to obtain attention maps and gradients from Transformers. I am working on Vision-Language transformers (image and text modalities), in my case the ALBEF model (paper linked here, although it is not needed at all to answer my question: https://proceedings.neurips.cc/paper_files/paper/2021/file/505259756244493872b7709a8a01b536-Paper.pdf).

I need to compare my results with a method that obtains gradients inside the forward function of an Attention class, through attn.register_hook.

Here is the Attention class in question:

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_gradients = None
        self.attention_map = None
        
    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients
        
    def get_attn_gradients(self):
        return self.attn_gradients
    
    def save_attention_map(self, attention_map):
        self.attention_map = attention_map
        
    def get_attention_map(self):
        return self.attention_map
    
    def forward(self, x, register_hook=False):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
                
        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)        

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

As long as the register_hook=True argument can be passed to the forward pass, I can obtain attention maps and gradients from the Attention class as I want (through self.save_attention_map(attn) and attn.register_hook(self.save_attn_gradients)).
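For reference, the mechanism method 1 relies on can be reduced to a few lines. This is a minimal sketch in which a toy tensor stands in for attn (the shapes, the save_grad helper and the toy loss are made up for illustration): a hook registered on a tensor during the forward pass is called with the gradient of the loss with respect to that exact tensor.

```python
import torch

saved = {}


def save_grad(grad):
    # Called during backward with d(loss)/d(attn); returning None keeps it unchanged
    saved["attn_grad"] = grad


x = torch.randn(2, 3, requires_grad=True)
attn = (x @ x.t()).softmax(dim=-1)  # toy stand-in for an attention map
attn.register_hook(save_grad)       # registered during the forward pass

weights = torch.tensor([0.0, 1.0])
loss = (attn * weights).sum()       # toy loss
loss.backward()

# d(loss)/d(attn) is `weights` broadcast over the rows: [[0., 1.], [0., 1.]]
print(saved["attn_grad"])
```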

However, this argument is not exposed by higher-level classes (the ALBEF class, for example). More importantly, in my case the ALBEF class is imported from a library, so I cannot adapt its code to pass register_hook=True down from the higher-level classes.
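One way to mimic method 1 without editing a library class is to register a forward hook on the attn_drop submodule and, inside that hook, register a tensor hook on the module's output. The sketch below is an assumption-laden illustration: TinyAttn is a hypothetical stand-in for ALBEF, and it assumes the relevant submodules can be found by a name ending in "attn_drop" (the real names in ALBEF would need to be checked with model.named_modules()). Because the tensor hook sits on the tensor returned by attn_drop, it should receive the same gradient as attn.register_hook(...) inside Attention.forward, recorded in forward order.

```python
import torch
import torch.nn as nn


class TinyAttn(nn.Module):
    """Hypothetical stand-in for a library attention module we cannot edit."""

    def __init__(self, dim):
        super().__init__()
        self.qk = nn.Linear(dim, 2 * dim)
        self.attn_drop = nn.Dropout(0.0)  # p=0., the default attn_drop in the Attention class

    def forward(self, x):
        q, k = self.qk(x).chunk(2, dim=-1)
        attn = (q @ k.transpose(-2, -1)).softmax(dim=-1)
        attn = self.attn_drop(attn)
        return attn @ x


model = nn.Sequential(TinyAttn(8), TinyAttn(8))
captured = {}  # module name -> {"attn": tensor, "grad": tensor}


def make_forward_hook(name):
    def hook(module, inputs, output):
        entry = captured.setdefault(name, {})
        entry["attn"] = output.detach()

        # Tensor hook on the exact tensor forward() keeps using,
        # which is what attn.register_hook(...) does inside the class.
        def save_grad(grad):
            entry["grad"] = grad.detach()

        output.register_hook(save_grad)

    return hook


handles = [
    module.register_forward_hook(make_forward_hook(name))
    for name, module in model.named_modules()
    if name.endswith("attn_drop")  # assumed naming convention
]

x = torch.randn(2, 4, 8)
model(x).mean().backward()

for handle in handles:
    handle.remove()  # detach the hooks once done

print(list(captured))  # ['0.attn_drop', '1.attn_drop']
```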

So I started trying to obtain attention maps and gradients from outside the class, using .register_forward_hook and .register_full_backward_hook (note that the difference with .register_backward_hook is not entirely clear to me).
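As for the two module-level hooks, here is a minimal sketch on a single nn.Linear (chosen only because its gradients are easy to check by hand). The forward hook receives (module, inputs, output); the full backward hook receives (module, grad_input, grad_output), where grad_output holds gradients with respect to the module's outputs and grad_input gradients with respect to its positional tensor inputs. The older register_backward_hook is deprecated in favor of register_full_backward_hook; for modules whose forward runs several operations, its grad_input may only cover the inputs of the last autograd operation rather than the module's own inputs.

```python
import torch
import torch.nn as nn

lin = nn.Linear(3, 2)
records = {}


def fwd_hook(module, inputs, output):
    # inputs: tuple of positional args to forward(); output: its return value
    records["out"] = output.detach()


def bwd_hook(module, grad_input, grad_output):
    # grad_output: gradients w.r.t. the module's outputs
    # grad_input: gradients w.r.t. the module's positional tensor inputs
    records["grad_out"] = grad_output[0].detach()
    records["grad_in"] = grad_input[0].detach()


h_fwd = lin.register_forward_hook(fwd_hook)
h_bwd = lin.register_full_backward_hook(bwd_hook)

x = torch.randn(4, 3, requires_grad=True)
lin(x).sum().backward()
h_fwd.remove()
h_bwd.remove()

# For loss = (x @ W.T + b).sum(): dL/dy is all ones and dL/dx = ones @ W
expected_grad_in = torch.ones(4, 2) @ lin.weight.detach()
print(torch.allclose(records["grad_in"], expected_grad_in))  # True
```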

I am now trying to get the same attention maps and gradients through 1) .register_hook in the forward pass and 2) .register_forward_hook and .register_full_backward_hook outside of the class.

I managed to retrieve the same attention maps, but not the same gradients. While investigating, I noticed that with a single Attention block the gradients from .register_hook and .register_full_backward_hook are identical, but as soon as there is more than one Attention block, the gradients obtained by the two methods differ.
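A likely explanation worth checking (this is a reading of the script below, not something the post confirms): forward hooks fire in forward order, but backward hooks fire in the order autograd runs the backward pass, i.e. from the last block to the first. In the script, grads_forward_list is built by iterating over inter.blocks in forward order, whereas grad_input_list and grad_output_list are appended by grad_hook_fn as the backward pass reaches each block. With depth=1 there is only one element, so the order cannot matter. This sketch uses a toy stack of nn.Linear layers (not the Block class) and only records the order in which the hooks run.

```python
import torch
import torch.nn as nn

order = {"forward": [], "backward": []}


def make_hooks(name):
    def fwd(module, inputs, output):
        order["forward"].append(name)

    def bwd(module, grad_input, grad_output):
        order["backward"].append(name)

    return fwd, bwd


blocks = nn.Sequential(*[nn.Linear(4, 4) for _ in range(3)])
for i, blk in enumerate(blocks):
    fwd, bwd = make_hooks(f"block{i}")
    blk.register_forward_hook(fwd)
    blk.register_full_backward_hook(bwd)

x = torch.randn(2, 4, requires_grad=True)
blocks(x).sum().backward()

print(order["forward"])   # ['block0', 'block1', 'block2']
print(order["backward"])  # ['block2', 'block1', 'block0']
```

If this is the cause, comparing grads_forward_list against the reversed hook list (e.g. grad_output_list[::-1]) in the script should report equal tensors.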


Here is a script you can run that reproduces the diverging gradients. With a single block (depth=1 in the main function) the gradients are the same; otherwise they differ (depth is set to 2 by default).

If you can help me understand why, I would be very thankful: I am clearly missing something about hooks in PyTorch, despite having looked into it several times…

(I am running Python 3.8.8 and PyTorch 1.12.1+cu102, although I doubt this affects the results):

"""Test on gradients retrieved through 1) forward + tensor.register_hook and 2) hooks on module outside of class."""

import torch
import torch.nn as nn

# Attention Class (used in Block class)
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attn_gradients = None
        self.attention_map = None
        
    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients
        
    def get_attn_gradients(self):
        return self.attn_gradients
    
    def save_attention_map(self, attention_map):
        self.attention_map = attention_map
        
    def get_attention_map(self):
        return self.attention_map
    
    def forward(self, x, register_hook=False):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
                
        if register_hook:
            self.save_attention_map(attn)
            attn.register_hook(self.save_attn_gradients)        

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    
# MLP class (used in Block class)
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
    
    
# Block class (used in Intermediate class)
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
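        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        # DropPath is not defined in this script (timm provides one); it is only
        # constructed when drop_path > 0., so the default drop_path=0. uses nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()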
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        
    def forward(self, x, register_hook=False):
        x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
    
    
# 'Intermediate' class to test if gradients obtained differ when using several blocks vs one (depending on the depth parameter!) 
class Intermediate(nn.Module):
    def __init__(self, dim, num_heads, depth=2):
        super().__init__()
        self.blocks = nn.ModuleList([
            Block(dim=dim, num_heads=num_heads)
            for i in range(depth)])
        # self.norm = nn.LayerNorm()
    
    def forward(self, x, register_hook=False):
        for block in self.blocks:
            x = block(x, register_hook=register_hook)
        # x = self.norm(x)
        return x

    
    
def are_tensors_equal(list1, list2):
    """Function testing if tensors from list of tensors are equal or not."""
    if len(list1) != len(list2):
        print("Length: False")
    else:
        print("Length: ok")
    
    for tensor1, tensor2 in zip(list1, list2):
        if not torch.equal(tensor1, tensor2):
            print("Equal: False")
        else:
            print("Equal: True")

            
def __main__():
    # Input Parameters
    batch_size = 2
    seq_length = 4
    input_dim = 768

    # Blk param
    num_heads = 8

    # Attn & Blk input
    input_tensor = torch.randn(batch_size, seq_length, input_dim)

    # Class instantiation
    
    depth = 2 # !!! Change it to "1" and gradients will be the same.
    
    inter = Intermediate(dim=input_dim, num_heads=num_heads, depth=depth)

    # --------------------------------------
    # --------------------------------------
    # METHOD 2 : Use hooks outside of class
    hook_list = []

    grad_input_list = []
    grad_output_list = []
    att_list = []

    def grad_hook_fn(module, grad_input, grad_output):
        grad_input_list.append(grad_input[0].detach())
        grad_output_list.append(grad_output[0].detach())

    def att_hook_fn(module, att_input, att_output):
        att_list.append(att_output.detach())


    for i, blk in enumerate(inter.blocks):
        hook_list.append(blk.attn.attn_drop.register_forward_hook(att_hook_fn))
        hook_list.append(blk.attn.attn_drop.register_full_backward_hook(grad_hook_fn))
    # --------------------------------------
    # --------------------------------------
        
        
        
    # Forward
    # Call the module itself rather than .forward(): idiomatic, and runs any hooks registered on it
    output = inter(input_tensor, register_hook=True)
    
    # Loss & Backward
    inter.zero_grad()
    loss = torch.mean(output) # loss simulation
    loss.backward(retain_graph=True)
    
    
    
    # --------------------------------------
    # --------------------------------------
    # METHOD 1 : retrieve gradients and attention saved from using forward(register_hook=True)
    atts_forward_list = []
    grads_forward_list = []

    for i, blk in enumerate(inter.blocks):
        atts_forward_list.append(blk.attn.get_attention_map())
        grads_forward_list.append(blk.attn.get_attn_gradients())
    print("att:", atts_forward_list, "\n")
    print("grads:", grads_forward_list)
    # --------------------------------------
    # --------------------------------------
    
    
    print("Checking if both methods give the same results for Intermediate class with depth = ", depth, " !")
    
    print("--- Equal: Att ---")
    are_tensors_equal(atts_forward_list, att_list)
    print("--- Equal: Forward (meth. 1) vs Grad input (meth. 2) ---")
    are_tensors_equal(grads_forward_list, grad_input_list)
    print("--- Equal: Forward (meth. 1) vs Grad output (meth. 2) ---")
    are_tensors_equal(grads_forward_list, grad_output_list)
    
    # --- Attention --- 
    
    # print("atts_forward_list: ", atts_forward_list)
    # print("att_list: ", att_list)
    
    # --- Gradients ---
    
    # print("grads_forward_list: ", grads_forward_list)
    # print("grad_input_list: ", grad_input_list)
    # print("grad_output_list: ", grad_output_list)
    
if __name__ == "__main__":
    __main__()
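To make that kind of check quick, here is a hypothetical helper (not part of the original script) that compares two lists of tensors both as-is and with the second list reversed. It uses torch.allclose instead of torch.equal so that tiny floating-point differences do not hide a match; the tolerances are the torch.allclose defaults and are an assumption, not a requirement.

```python
import torch


def lists_match(list1, list2, rtol=1e-5, atol=1e-8):
    """True if both lists have the same length and every pair is allclose."""
    return len(list1) == len(list2) and all(
        t1.shape == t2.shape and torch.allclose(t1, t2, rtol=rtol, atol=atol)
        for t1, t2 in zip(list1, list2)
    )


def diagnose(forward_order, hook_order):
    """Compare as-is and with the hook list reversed (backward-hook order)."""
    return {
        "as_is": lists_match(forward_order, hook_order),
        "reversed": lists_match(forward_order, hook_order[::-1]),
    }


# Toy data: the hook list holds the same tensors in reverse order
a, b = torch.zeros(2, 2), torch.ones(2, 2)
print(diagnose([a, b], [b, a]))  # {'as_is': False, 'reversed': True}
```

With the script's variables, this would be called as diagnose(grads_forward_list, grad_output_list).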

Thanks for reading all of this!
