The problem in short:
We have a model consisting of three layers (cross-attention, self-attention and FFN), and when we feed it an input and inspect its computation graph, one specific layer of the model (ffn.in_layer, an nn.Linear) does not appear in the graph.
Code snippet with explanations:
import torch
import einops
import torch_pruning as tp  # assumed import for the tp.function.BasePruningFunc base class used below
from torch import nn
class RoiAlignDecoderLayer(nn.Module):
    def __init__(self, n_channels, n_locs, n_heads, drop_rate: float = 0.0):
        super().__init__()
        self.cross_attn = RoiAlignedAttention(n_channels, n_locs, n_heads, drop_rate=drop_rate)
        self.self_attn_norm = nn.LayerNorm(n_channels)
        self.self_attn = SelfAttention(n_channels, n_heads, attn_drop=drop_rate, proj_drop=drop_rate)
        self.ffn_norm = nn.LayerNorm(n_channels)
        self.ffn = FFN(n_channels, n_channels, 4 * n_channels, drop_rate=drop_rate)
        self.init_weights()

    def init_weights(self):
        nn.init.constant_(self.ffn.out_layer.weight, 0.0)
        nn.init.constant_(self.ffn.out_layer.bias, 0.0)

    def forward(self, queries, aligned_features):
        queries = queries + self.cross_attn(queries, aligned_features)
        queries = queries + self.self_attn(self.self_attn_norm(queries))
        queries = queries + self.ffn(self.ffn_norm(queries))
        return queries
class RoiAlignedAttention(nn.Module):
    def __init__(self, n_channels, n_locs, n_heads=8, drop_rate: float = 0.0):
        super().__init__()
        self.n_locs = n_locs
        self.n_heads = n_heads
        self.attn_weights = nn.Linear(n_channels, n_heads * n_locs)
        self.attn_dropout = nn.Dropout(drop_rate)
        self.projection = nn.Linear(n_channels, n_channels)
        self.proj_dropout = nn.Dropout(drop_rate)

    def forward(self, queries: torch.Tensor, aligned_features: torch.Tensor):
        aligned_features = einops.rearrange(
            aligned_features, "B N (nh Ch) n_locs -> B N nh Ch n_locs", nh=self.n_heads
        )
        attn_weights = self.attn_weights(queries)
        attn_weights = einops.rearrange(
            attn_weights, "B N (nh n_locs) -> B N nh (n_locs)", nh=self.n_heads
        ).softmax(dim=-1)
        attn_weights = self.attn_dropout(attn_weights)
        outputs = einops.einsum(
            aligned_features,
            attn_weights,
            "B N nh Ch n_locs, B N nh n_locs -> B N nh Ch",
        )
        outputs = self.projection(outputs.flatten(-2))
        outputs = self.proj_dropout(outputs)
        return outputs
class SelfAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_norm=False,
        attn_drop=0.0,
        proj_drop=0.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)
        # x = F.scaled_dot_product_attention(
        #     q, k, v,
        #     dropout_p=self.attn_drop.p,
        # )
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class FFN(nn.Module):
    def __init__(self, in_c: int, out_c: int, mid_c: int, drop_rate: float = 0.0):
        super().__init__()
        self.in_layer = nn.Linear(in_c, mid_c)
        self.act = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(drop_rate, inplace=True)
        self.out_layer = nn.Linear(mid_c, out_c)

    def forward(self, x):
        x = self.in_layer(x)
        x = self.dropout(x)
        x = self.act(x)
        x = self.out_layer(x)
        return x
class SelfAttentionPruner(tp.function.BasePruningFunc):
    TARGET_MODULES = SelfAttention

    def check(self, layer, idxs, to_output):
        super().check(layer, idxs, to_output)
        dim = self.get_out_channels(layer)  # SelfAttention has no `dim`/`embed_dim` attribute, so derive it
        assert (dim - len(idxs)) % layer.num_heads == 0, (
            "dim (%d) of SelfAttention after pruning must divide evenly by `num_heads` (%d)"
            % (dim, layer.num_heads)
        )
    def prune_out_channels(self, layer, idxs: list) -> nn.Module:
        dim = self.get_out_channels(layer)
        keep_idxs = list(set(range(dim)) - set(idxs))
        keep_idxs.sort()
        pruning_idxs_repeated = idxs + [i + dim for i in idxs] + [i + 2 * dim for i in idxs]
        keep_idxs_3x_repeated = list(set(range(3 * dim)) - set(pruning_idxs_repeated))
        layer.qkv.weight = self._prune_parameter_and_grad(layer.qkv.weight, keep_idxs, 1)
        layer.qkv.weight = self._prune_parameter_and_grad(layer.qkv.weight, keep_idxs_3x_repeated, 0)
        layer.qkv.in_features = len(keep_idxs)
        layer.qkv.out_features = len(keep_idxs_3x_repeated)
        layer.proj.weight = self._prune_parameter_and_grad(layer.proj.weight, keep_idxs, 1)
        layer.proj.weight = self._prune_parameter_and_grad(layer.proj.weight, keep_idxs, 0)
        layer.proj.bias = self._prune_parameter_and_grad(layer.proj.bias, keep_idxs, 0)
        layer.proj.in_features = layer.proj.out_features = len(keep_idxs)
        assert (layer.head_dim - len(idxs)) % layer.num_heads == 0
        layer.head_dim = len(keep_idxs) // layer.num_heads
        print("PRUNING SelfAttention...")
        return layer

    prune_in_channels = prune_out_channels

    def get_out_channels(self, layer):
        return layer.head_dim * layer.num_heads

    def get_in_channels(self, layer):
        return self.get_out_channels(layer)
class RoiAlignedAttentionPruner(tp.function.BasePruningFunc):
    TARGET_MODULES = RoiAlignedAttention

    def check(self, layer, idxs, to_output):
        super().check(layer, idxs, to_output)

    def prune_out_channels(self, layer, idxs: list) -> nn.Module:
        dim = self.get_out_channels(layer)
        keep_idxs = list(set(range(dim)) - set(idxs))
        keep_idxs.sort()
        layer.attn_weights.weight = self._prune_parameter_and_grad(layer.attn_weights.weight, keep_idxs, 1)
        layer.attn_weights.in_features = len(keep_idxs)
        layer.projection.weight = self._prune_parameter_and_grad(layer.projection.weight, keep_idxs, 1)
        layer.projection.weight = self._prune_parameter_and_grad(layer.projection.weight, keep_idxs, 0)
        layer.projection.bias = self._prune_parameter_and_grad(layer.projection.bias, keep_idxs, 0)
        layer.projection.in_features = layer.projection.out_features = len(keep_idxs)
        print("PRUNING RoiAlignedAttention...")
        return layer

    prune_in_channels = prune_out_channels

    def get_out_channels(self, layer):
        return layer.attn_weights.in_features

    def get_in_channels(self, layer):
        return self.get_out_channels(layer)
radl = RoiAlignDecoderLayer(64, 75, 8, 0.1)
queries = torch.randn(1, 64, 64)
aligned_features = torch.randn(1, 64, 64, 75)
example_inputs = [queries, aligned_features]
print(radl)
Let's do a forward pass and record the grad_fn of each Linear layer via forward hooks (we're interested in radl.ffn.in_layer):
radl.eval()
gradfn2module = {}

def _record_grad_fn(module, inputs, outputs):
    print('module', type(module))
    gradfn2module[outputs.grad_fn] = module

hooks = [
    m.register_forward_hook(_record_grad_fn)
    for m in radl.modules()
    if isinstance(m, nn.Linear)
]
out = radl(*example_inputs)
for hook in hooks:
    hook.remove()

slicky_grad_fn = None
module2gradfn = {value: key for key, value in gradfn2module.items()}
for m, g in module2gradfn.items():
    if (
        isinstance(m, nn.Linear)
        and (m.in_features == 64)
        and (m.out_features == 256)
    ):
        print("SLICKY LAYER and ITS GRAD_FN", m, g, g.next_functions)
        if slicky_grad_fn is None:
            slicky_grad_fn = g
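(A small side check, not part of the original code: for a plain nn.Linear with a bias and a 2D input, the node recorded by such a hook is typically an AddmmBackward0; the exact class name can vary between PyTorch versions.)
lin = nn.Linear(64, 256)  # hypothetical standalone layer, only to inspect the node type
y = lin(torch.randn(2, 64))
print(type(y.grad_fn).__name__)  # usually 'AddmmBackward0'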
Now let's take the output and propagate back from it along the grad_fn / next_functions chain, looking for the slicky grad_fn:
processing_stack = [out.grad_fn]
while len(processing_stack) > 0:
    grad_fn = processing_stack.pop(-1)
    # print(grad_fn)
    if grad_fn is slicky_grad_fn:
        raise Exception('hooray!! the slicky grad_fn has been found')
    if hasattr(grad_fn, "next_functions"):
        for f in grad_fn.next_functions:
            processing_stack.append(f[0])
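(For completeness, an equivalent traversal with a visited set, so that nodes shared by the residual branches are not pushed more than once; this is only a sketch and it visits exactly the same set of nodes, so the outcome described below is unchanged:)
visited = set()
stack = [out.grad_fn]
while stack:
    grad_fn = stack.pop()
    if grad_fn is None or grad_fn in visited:
        continue
    visited.add(grad_fn)
    if grad_fn is slicky_grad_fn:
        raise Exception('hooray!! the slicky grad_fn has been found')
    for f in getattr(grad_fn, 'next_functions', ()):
        stack.append(f[0])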
As a result, we never reach the grad_fn corresponding to radl.ffn.in_layer, although the same procedure does work for the other Linear layers.
What could be the reason that this layer isn't present in the computation graph?
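(In case it helps to isolate things, here is a minimal FFN-only sketch of the same check, using only the FFN class defined above; the names below are made up for illustration and it makes no claim about the outcome:)
ffn = FFN(64, 64, 256, drop_rate=0.1)
ffn.eval()
recorded = {}

def _grab(module, inputs, output):
    # record the grad_fn of in_layer's output; return None so the output is left untouched
    recorded['in_layer_grad_fn'] = output.grad_fn

h = ffn.in_layer.register_forward_hook(_grab)
y = ffn(torch.randn(1, 64, 64))
h.remove()

found = False
stack = [y.grad_fn]
while stack:
    node = stack.pop()
    if node is recorded['in_layer_grad_fn']:
        found = True
        break
    for f in getattr(node, 'next_functions', ()):
        if f[0] is not None:
            stack.append(f[0])
print('in_layer grad_fn reachable from the output:', found)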
Thanks in advance.