Monkey-patching the `forward()` pass of an `nn.Module`

I am trying to monkey-patch the `forward()` method of an `nn.Module`. Here's my module:

import torch.nn as nn 

class GPT5(nn.Module):
    embed_dim = 768
    num_heads = 12
    q_proj = nn.Linear(embed_dim, embed_dim)
    head_dim = embed_dim // num_heads
    scale = head_dim**-0.5

    def forward(self, hidden_states):
        return self.q_proj(hidden_states) * self.scale

The following works as usual:

import torch 

gpt5 = GPT5()
gpt5(torch.randn(1, 10, 768)).size()

Now monkey-patch it:

gpt5 = GPT5()
new_forward = lambda x: gpt5.forward(x) + 1
gpt5.forward = new_forward

The following then raises an error:

gpt5(torch.randn(1, 10, 768)).size()
Traceback (most recent call last)
in <module>:1

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py:1194 in _call_impl

  1191         # this function, and just call forward.
  1192         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o
  1193                 or _global_forward_hooks or _global_forward_pre_hooks):
❱ 1194             return forward_call(*input, **kwargs)
  1195         # Do not call functions when jit is used
  1196         full_backward_hooks, non_full_backward_hooks = [], []
  1197         if self._backward_hooks or _global_backward_hooks:
in <lambda>:2
in <lambda>:2
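
The repeated <lambda> frames point at the root cause: after the assignment, gpt5.forward is the lambda itself, so gpt5.forward(x) inside the lambda calls the lambda again and the call never bottoms out. Here is a minimal, torch-free sketch of the same mistake (the Box class is made up purely for illustration):

class Box:
    def forward(self, x):
        return x * 2

box = Box()

# The lambda looks up box.forward at call time, but by then
# box.forward is this very lambda, so it keeps calling itself.
box.forward = lambda x: box.forward(x) + 1

try:
    box.forward(3)
except RecursionError as err:
    print("RecursionError:", err)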

Does the following work for you?

class GPT5(nn.Module):
    embed_dim = 768
    num_heads = 12
    q_proj = nn.Linear(embed_dim, embed_dim)
    head_dim = embed_dim // num_heads
    scale = head_dim**-0.5

    def forward(self, hidden_states):
        return self.q_proj(hidden_states) * self.scale

gpt5 = GPT5()
old_forward = gpt5.forward

new_forward = lambda x: old_forward(x) + 1
gpt5.forward = new_forward

gpt5(torch.randn(1, 10, 768))
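
As an aside, if all you need is to tweak the module's output, you can get the same "+ 1" effect without replacing forward at all by registering a forward hook. This is just a sketch; the hook name is my own:

gpt5 = GPT5()

def add_one(module, inputs, output):
    # A forward hook that returns a non-None value replaces the module's output.
    return output + 1

handle = gpt5.register_forward_hook(add_one)
print(gpt5(torch.randn(1, 10, 768)).size())  # torch.Size([1, 10, 768])
handle.remove()  # detach the hook once it is no longer needed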


Thanks! It works. I hadn't realized the fix would be to store the original forward in a separate variable first.