I am trying to monkey-patch the `forward()` method of an `nn.Module`. Here's my `nn.Module`:
```python
import torch.nn as nn


class GPT5(nn.Module):
    embed_dim = 768
    num_heads = 12
    q_proj = nn.Linear(embed_dim, embed_dim)
    head_dim = embed_dim // num_heads
    scale = head_dim**-0.5

    def forward(self, hidden_states):
        return self.q_proj(hidden_states) * self.scale
```
The following works as usual:
```python
import torch

gpt5 = GPT5()
gpt5(torch.randn(1, 10, 768)).size()
```
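This returns `torch.Size([1, 10, 768])`, as expected.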
Now I monkey-patch it:
```python
gpt5 = GPT5()
new_forward = lambda x: gpt5.forward(x) + 1
gpt5.forward = new_forward
```
The following then raises an error:
```python
gpt5(torch.randn(1, 10, 768)).size()
```
```
Traceback (most recent call last)
in <module>:1

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py:1194 in _call_impl

  1191 │   │   # this function, and just call forward.
  1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o
  1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):
❱ 1194 │   │   │   return forward_call(*input, **kwargs)
  1195 │   │   # Do not call functions when jit is used
  1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []
  1197 │   │   if self._backward_hooks or _global_backward_hooks:
in <lambda>:2
in <lambda>:2
```
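The repeated `in <lambda>:2` frames suggest infinite recursion: once `gpt5.forward = new_forward` is assigned, `gpt5.forward` *is* the lambda, so the `gpt5.forward(x)` call inside it invokes itself. A minimal sketch of the workaround I have in mind, saving the original bound method before patching (`original_forward` is just a name I picked):

```python
import torch

gpt5 = GPT5()
original_forward = gpt5.forward  # bound method resolved on the class, saved before patching
gpt5.forward = lambda x: original_forward(x) + 1  # calls the saved method, not itself

gpt5(torch.randn(1, 10, 768)).size()  # torch.Size([1, 10, 768]), no recursion
```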