No grad_fn in computation graph for specific layer after doing forward call

The problem in short:
We have a model consisting of 3 layers (cross-attn, self-attn and ffn). When we feed it an input and inspect its computation graph, one specific layer of the model (ffn.in_layer, a Linear layer) appears to be missing from the graph.

Code snippet with explanations:

from torch import nn
import torch

class RoiAlignDecoderLayer(nn.Module):
    """Decoder layer made of cross-attention, self-attention and a feed-forward block.

    Every sub-layer is applied residually to ``queries``. The self-attention
    and feed-forward branches normalise their input with ``LayerNorm`` first;
    the cross-attention branch receives ``queries`` as they are.
    """

    def __init__(self, n_channels, n_locs, n_heads, drop_rate: float = 0.0):
        """Build the three sub-layers.

        Args:
            n_channels: Feature width shared by all sub-layers.
            n_locs: Forwarded to ``RoiAlignedAttention`` as its ``n_locs``.
            n_heads: Number of heads for both attention modules.
            drop_rate: Dropout probability forwarded to every sub-layer.
        """
        # nn.Module has to be initialised before sub-modules are assigned;
        # without this call the assignments below raise AttributeError.
        super().__init__()

        self.cross_attn = RoiAlignedAttention(n_channels, n_locs, n_heads, drop_rate=drop_rate)

        self.self_attn_norm = nn.LayerNorm(n_channels)
        self.self_attn = SelfAttention(n_channels, n_heads, attn_drop=drop_rate, proj_drop=drop_rate)

        self.ffn_norm = nn.LayerNorm(n_channels)
        self.ffn = FFN(n_channels, n_channels, 4 * n_channels, drop_rate=drop_rate)

    def init_weights(self):
        """Zero the weight and bias of the feed-forward output projection."""
        nn.init.constant_(self.ffn.out_layer.weight, 0.0)
        nn.init.constant_(self.ffn.out_layer.bias, 0.0)

    def forward(self, queries, aligned_features):
        """Apply cross-attention, self-attention and FFN, each as a residual update.

        Args:
            queries: Query tensor that is updated by every sub-layer.
            aligned_features: Features consumed by the cross-attention module.

        Returns:
            The updated queries.
        """
        queries = queries + self.cross_attn(queries, aligned_features)
        queries = queries + self.self_attn(self.self_attn_norm(queries))
        queries = queries + self.ffn(self.ffn_norm(queries))
        return queries

class RoiAlignedAttention(nn.Module):
    """Cross-attention whose weights are predicted directly from the queries.

    A linear layer maps each query to ``n_heads * n_locs`` logits; after a
    softmax these weight the per-head slices of ``aligned_features``.
    """

    def __init__(self, n_channels, n_locs, n_heads=8, drop_rate: float = 0.0):
        """Create the weight predictor, the output projection and their dropouts.

        Args:
            n_channels: Input and output feature width.
            n_locs: Number of locations attended to per head.
            n_heads: Number of attention heads.
            drop_rate: Dropout probability for the attention weights and the output.
        """
        # Required before assigning sub-modules to an nn.Module.
        super().__init__()
        self.n_locs = n_locs
        self.n_heads = n_heads

        self.attn_weights = nn.Linear(n_channels, n_heads * n_locs)
        self.attn_dropout = nn.Dropout(drop_rate)
        self.projection = nn.Linear(n_channels, n_channels)
        self.proj_dropout = nn.Dropout(drop_rate)

    def forward(self, queries: torch.Tensor, aligned_features: torch.Tensor):
        """Aggregate ``aligned_features`` with weights predicted from ``queries``.

        Args:
            queries: Tensor laid out as ``B N C`` (it feeds the ``attn_weights`` linear layer).
            aligned_features: Tensor laid out as ``B N (nh Ch) n_locs``.

        Returns:
            Tensor with the head and channel axes merged again, after projection and dropout.
        """
        aligned_features = einops.rearrange(aligned_features, "B N (nh Ch) n_locs -> B N nh Ch n_locs", nh=self.n_heads)
        attn_weights = self.attn_weights(queries)
        # NOTE(review): the softmax argument was truncated in the original
        # snippet; normalising over the trailing n_locs axis is assumed.
        attn_weights = einops.rearrange(
            attn_weights, "B N (nh n_locs) -> B N nh (n_locs)", nh=self.n_heads
        ).softmax(dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # einops.einsum takes the tensors first and the pattern last.
        # NOTE(review): the operand lines were missing from the snippet; the
        # two operands below are the only ones that match the pattern.
        outputs = einops.einsum(
            aligned_features,
            attn_weights,
            "B N nh Ch n_locs, B N nh n_locs -> B N nh Ch",
        )
        outputs = self.projection(outputs.flatten(-2))
        outputs = self.proj_dropout(outputs)
        return outputs

class SelfAttention(nn.Module):
    """Multi-head self-attention with a fused qkv projection.

    NOTE(review): the ``__init__`` parameter list was missing from the
    original snippet. The parameter names are the ones the body uses; the
    order and defaults are assumed to follow timm's ``Attention`` - confirm.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.0,
        proj_drop=0.0,
        norm_layer=nn.LayerNorm,
    ):
        """Create the qkv projection, optional q/k norms and the output projection.

        Args:
            dim: Feature width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            qkv_bias: Whether the fused qkv projection has a bias.
            qk_norm: If true, apply ``norm_layer(head_dim)`` to q and k.
            attn_drop: Dropout probability on the attention matrix.
            proj_drop: Dropout probability on the projected output.
            norm_layer: Factory used for the q/k norms when ``qk_norm`` is set.
        """
        # Required before assigning sub-modules to an nn.Module.
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Run self-attention over ``x`` of shape ``(B, N, C)``; the output has the same shape."""
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim), then split into q, k, v.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        # Attention is spelled out explicitly here instead of calling
        # F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p).
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v

        # Merge the heads back: (B, num_heads, N, head_dim) -> (B, N, C).
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class FFN(nn.Module):
    """Feed-forward block: Linear -> Dropout -> ReLU -> Linear."""

    def __init__(self, in_c: int, out_c: int, mid_c: int, drop_rate: float = 0.0):
        """Create both linear layers plus the in-place dropout and activation.

        Args:
            in_c: Input feature width.
            out_c: Output feature width.
            mid_c: Hidden feature width between the two linear layers.
            drop_rate: Dropout probability applied to the hidden features.
        """
        # Required before assigning sub-modules to an nn.Module; without it
        # the first assignment below raises AttributeError.
        super().__init__()
        self.in_layer = nn.Linear(in_c, mid_c)
        # NOTE(review): both ops below run in place on the tensor returned by
        # ``in_layer``. Autograd re-records the history of a tensor that is
        # modified in place, so a grad_fn captured from ``in_layer``'s output
        # in a forward hook may no longer be the node reachable from the
        # final output - confirm against the autograd notes on in-place ops.
        self.act = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(drop_rate, inplace=True)
        self.out_layer = nn.Linear(mid_c, out_c)

    def forward(self, x):
        """Return ``out_layer(act(dropout(in_layer(x))))``."""
        x = self.in_layer(x)
        x = self.dropout(x)
        x = self.act(x)
        x = self.out_layer(x)
        return x

class SelfAttentionPruner(tp.function.BasePruningFunc):
    """Pruning function for ``SelfAttention`` layers.

    Input and output channels are pruned together: the same indices are
    removed from the input side of ``qkv``, from each of the q/k/v thirds of
    its output, and from both sides of ``proj``.
    """

    TARGET_MODULES = SelfAttention

    def check(self, layer, idxs, to_output):
        """Run the base checks and verify the remaining width still splits over the heads."""
        super().check(layer, idxs, to_output)
        # The width is derived from head_dim * num_heads because those are the
        # attributes this pruner keeps up to date; the original read
        # ``layer.dim``, which the visible SelfAttention code never sets.
        # NOTE(review): the tuple feeding the message was truncated in the
        # original snippet; reporting the post-pruning width is assumed.
        remaining = self.get_out_channels(layer) - len(idxs)
        assert (
            remaining % layer.num_heads == 0
        ), "dim (%d) of MultiheadAttention after pruning must divide evenly by `num_heads` (%d)" % (
            remaining,
            layer.num_heads,
        )

    def prune_out_channels(self, layer, idxs: list) -> nn.Module:
        """Remove the channels ``idxs`` from ``layer`` in place and return it.

        Args:
            layer: The ``SelfAttention`` module to prune.
            idxs: Channel indices (in ``[0, dim)``) to remove.

        Returns:
            The same ``layer`` object with pruned parameters and updated sizes.
        """
        dim = self.get_out_channels(layer)

        # Sorted so the surviving channels keep their relative order
        # regardless of set iteration order.
        keep_idxs = sorted(set(range(dim)) - set(idxs))
        # Validate before mutating anything. The remaining width must split
        # evenly over the heads (the original compared head_dim - len(idxs),
        # which only matches this when head_dim is a multiple of num_heads).
        assert len(keep_idxs) % layer.num_heads == 0

        # The fused qkv output stacks q, k and v, so each index is dropped
        # once per third.
        pruning_idxs_repeated = idxs + [i + dim for i in idxs] + [i + 2 * dim for i in idxs]
        keep_idxs_3x_repeated = sorted(set(range(3 * dim)) - set(pruning_idxs_repeated))

        layer.qkv.weight = self._prune_parameter_and_grad(layer.qkv.weight, keep_idxs, 1)
        layer.qkv.weight = self._prune_parameter_and_grad(layer.qkv.weight, keep_idxs_3x_repeated, 0)
        if layer.qkv.bias is not None:
            # Only present when the layer was built with qkv_bias=True; it
            # must shrink together with the output rows of the weight.
            layer.qkv.bias = self._prune_parameter_and_grad(layer.qkv.bias, keep_idxs_3x_repeated, 0)
        layer.qkv.in_features = len(keep_idxs)
        layer.qkv.out_features = len(keep_idxs_3x_repeated)

        layer.proj.weight = self._prune_parameter_and_grad(layer.proj.weight, keep_idxs, 1)
        layer.proj.weight = self._prune_parameter_and_grad(layer.proj.weight, keep_idxs, 0)
        layer.proj.bias = self._prune_parameter_and_grad(layer.proj.bias, keep_idxs, 0)

        layer.proj.in_features = layer.proj.out_features = len(keep_idxs)
        # NOTE(review): layer.scale was computed from the original head_dim
        # and is left untouched here, as in the original - confirm intended.
        layer.head_dim = len(keep_idxs) // layer.num_heads
        print("PRUNING SelfAttention...")
        return layer

    # Input and output widths are tied, so both directions share one routine.
    prune_in_channels = prune_out_channels

    def get_out_channels(self, layer):
        """Return the current channel width, ``head_dim * num_heads``."""
        return layer.head_dim * layer.num_heads

    def get_in_channels(self, layer):
        """Return the input width, which equals the output width."""
        return self.get_out_channels(layer)

class RoiAlignedAttentionPruner(tp.function.BasePruningFunc):
    """Pruning function for ``RoiAlignedAttention`` layers.

    The same channel indices are removed from the input side of
    ``attn_weights`` and from both sides of ``projection``.
    """

    TARGET_MODULES = RoiAlignedAttention

    def check(self, layer, idxs, to_output):
        """Delegate to the base-class checks; no extra constraints are added."""
        super().check(layer, idxs, to_output)

    def prune_out_channels(self, layer, idxs: list) -> nn.Module:
        """Remove the channels ``idxs`` from ``layer`` in place and return it.

        Args:
            layer: The ``RoiAlignedAttention`` module to prune.
            idxs: Channel indices to remove.

        Returns:
            The same ``layer`` object with pruned parameters and updated sizes.
        """
        dim = self.get_out_channels(layer)

        # Sorted so the surviving channels keep their relative order
        # regardless of set iteration order.
        keep_idxs = sorted(set(range(dim)) - set(idxs))
        kept = len(keep_idxs)

        # Only the input side of the weight predictor depends on the channel
        # width; its output size (n_heads * n_locs) is left as is.
        layer.attn_weights.weight = self._prune_parameter_and_grad(layer.attn_weights.weight, keep_idxs, 1)
        layer.attn_weights.in_features = kept

        layer.projection.weight = self._prune_parameter_and_grad(layer.projection.weight, keep_idxs, 1)
        layer.projection.weight = self._prune_parameter_and_grad(layer.projection.weight, keep_idxs, 0)
        layer.projection.bias = self._prune_parameter_and_grad(layer.projection.bias, keep_idxs, 0)
        layer.projection.in_features = layer.projection.out_features = kept

        print("PRUNING RoiAlignedAttention...")
        return layer

    # Input and output widths are tied, so both directions share one routine.
    prune_in_channels = prune_out_channels

    def get_out_channels(self, layer):
        """Return the current channel width, read from ``attn_weights.in_features``."""
        return layer.attn_weights.in_features

    def get_in_channels(self, layer):
        """Return the input width, which equals the output width."""
        return self.get_out_channels(layer)

# Layer under test: 64 channels, 75 locations per query, 8 heads, dropout 0.1.
radl = RoiAlignDecoderLayer(n_channels=64, n_locs=75, n_heads=8, drop_rate=0.1)

# Random example inputs matching the layer's two forward arguments.
queries = torch.randn(1, 64, 64)
aligned_features = torch.randn(1, 64, 64, 75)
example_inputs = [queries, aligned_features]

Let’s do a forward pass and record the grad_fn of each Linear layer’s output (we’re interested in radl.ffn.in_layer):

# Maps the grad_fn of a module's output to the module that produced it.
gradfn2module = {}


def _record_grad_fn(module, inputs, outputs):
    """Forward hook: remember which module produced ``outputs.grad_fn``."""
    print('module', type(module))
    gradfn2module[outputs.grad_fn] = module


# NOTE(review): the element expression of this comprehension was missing in
# the original snippet; registering the hook on every Linear is assumed.
hooks = [
    m.register_forward_hook(_record_grad_fn)
    for m in radl.modules()
    if isinstance(m, nn.Linear)
]

out = radl(*example_inputs)

# NOTE(review): the loop body was missing in the original snippet; removing
# the hooks after the forward pass is assumed.
for hook in hooks:
    hook.remove()

slicky_grad_fn = None
module2gradfn = {value: key for key, value in gradfn2module.items()}
for m, g in module2gradfn.items():
    # The 64 -> 256 Linear is the layer of interest (radl.ffn.in_layer).
    if (
        isinstance(m, nn.Linear)
        and (m.in_features == 64)
        and (m.out_features == 256)
    ):
        print("SLICKY LAYER and ITS GRAD_FN", m, g, g.next_functions)
        if slicky_grad_fn is None:
            slicky_grad_fn = g

Now let’s take the output and walk backwards through the graph via its grad_fn property, looking for the slicky grad_fn:

# Depth-first walk over the autograd graph, starting at the output's grad_fn.
processing_stack = [out.grad_fn]
# Nodes already expanded; residual connections make many nodes reachable
# along several paths, and each needs to be inspected only once.
visited = set()

while processing_stack:
    grad_fn = processing_stack.pop()
    if grad_fn in visited:
        continue
    visited.add(grad_fn)
    if grad_fn is slicky_grad_fn:
        raise Exception('hooray!! the slicky grad_fn has been found')
    if hasattr(grad_fn, "next_functions"):
        # NOTE(review): the loop body was missing in the original snippet;
        # pushing each non-None parent node is assumed. Entries of
        # next_functions are (node, index) pairs.
        for f in grad_fn.next_functions:
            if f[0] is not None:
                processing_stack.append(f[0])

As a result, we don’t find the grad_fn corresponding to radl.ffn.in_layer, although the same search does work for the other layers.
What could be the reason that this layer isn’t present in the computation graph?
Thanks in advance.

Without looking too closely: is it possible that you are comparing a different nn.Linear module’s grad_fn directly using `is` when the intent is to check the type of the grad_fn instead? In that case you might want to use `isinstance` rather than `is`.

Consider the following toy example:

>>> a = torch.randn(10, requires_grad=True)
>>> b = torch.randn(2, requires_grad=True)
>>> o = a.sum()
>>> o2 = b.sum()
>>> o2.grad_fn is o.grad_fn
False
>>> isinstance(o2.grad_fn, type(o.grad_fn))
True
1 Like

Thanks for your answer.
Actually, not really… I need to make sure that it’s the very same grad_fn object, not just the same type, because in the latter case I would match a bunch of grad_fn functions that produce other tensors but have the same type.
In my case I compare against every grad_fn in the graph just to underline the point that the grad_fn linked to my target layer radl.ffn.in_layer is not present in the graph.

Any ideas on which direction I should look in, at least?