I’m currently using a GPT model implemented in PyTorch, and I want to make some changes to the attention mechanism. Here’s my code (some parts are omitted to keep it readable):
import math
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
from .attentions import *
class GPTConfig:
""" base GPT config, params common to all GPT versions """
embd_pdrop = 0.1
resid_pdrop = 0.1
attn_pdrop = 0.1
def __init__(self, vocab_size, block_size, **kwargs):
self.vocab_size = vocab_size
self.block_size = block_size
for k,v in kwargs.items():
            setattr(self, k, v)  # equivalent to self.<k> = v for any attribute name, e.g. setattr(person, "age", 18)
class Block(nn.Module):
""" an unassuming Transformer block """
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
if config.attention == 'dot':
self.attn = CausalSelfDotAttention(config)
elif config.attention == 'mult':
self.attn = CausalSelfMultAttention(config)
elif config.attention == 'add':
self.attn = CausalSelfAddAttention(config)
        elif config.attention == 'pos_enc':
            self.attn = CausalSelfPosEncAttention(config)
        else:
            raise ValueError(f"unknown attention type: {config.attention!r}")
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.GELU(),
nn.Linear(4 * config.n_embd, config.n_embd),
nn.Dropout(config.resid_pdrop),
)
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x)) # norm first
return x
class FusionBlock(nn.Module):
def __init__(self, config):
super().__init__()
        config.attention = 'dot'  # both inner blocks use dot-product attention
self.block1 = Block(config)
self.block2 = Block(config)
self.fusion = nn.Linear(2*config.n_embd, 4*config.n_embd)
self.activate = nn.GELU()
self.token = nn.Linear(4*config.n_embd, config.n_embd)
self.position = nn.Linear(4*config.n_embd, config.n_embd)
def forward(self, x, p):
x = self.block1(x)
p = self.block2(p)
fused = self.activate(self.fusion(torch.cat((x, p), dim=-1)))
x = self.activate(self.token(fused))
p = self.activate(self.position(fused))
return x, p
class GPT(nn.Module):
""" the full GPT language model, with a context size of block_size """
def __init__(self, config):
super().__init__()
# input embedding stem
self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
self.drop = nn.Dropout(config.embd_pdrop)
        # attention type (to allow structural modifications)
self.attention = config.attention
# transformer
if self.attention == 'pos_fus':
self.blocks = nn.Sequential(*[FusionBlock(config) for _ in range(config.n_layer)])
else:
self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
# decoder head
self.ln_f = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.block_size = config.block_size
def forward(self, idx, targets=None):
b, t = idx.size()
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."  # not enough positional embeddings
# forward the GPT model
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector; (B, T, C)
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector; (1, T, C)
if self.attention == 'pos_fus':
token = self.drop(token_embeddings)
position = self.drop(position_embeddings)
x, _ = self.blocks(token, position)
else:
x = self.drop(token_embeddings + position_embeddings)
x = self.blocks(x)
x = self.ln_f(x)
        logits = self.head(x)  # logits for every position in the sequence; at sampling (test) time only the final position's logits are used
# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))  # flatten logits to 2-D and targets to 1-D
return logits, loss
When I try to run the code (using the “pos_fus” attention), it says:
TypeError: forward() takes 2 positional arguments but 3 were given
From the traceback, the error seems to come from this line:
x, _ = self.blocks(token, position)
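To narrow it down, I also tried a minimal standalone sketch (the names below are just for illustration, not from my actual model), and it raises the same TypeError, which makes me suspect that nn.Sequential only passes a single input from one module to the next:
import torch
import torch.nn as nn

class TwoInputBlock(nn.Module):
    # toy module whose forward takes two tensors, like my FusionBlock
    def forward(self, x, p):
        return x + 1, p + 1

blocks = nn.Sequential(TwoInputBlock(), TwoInputBlock())
x = torch.zeros(1, 4, 8)
p = torch.zeros(1, 4, 8)
out = blocks(x, p)  # TypeError: forward() takes 2 positional arguments but 3 were given
If that is indeed the cause, what is the idiomatic way to chain blocks that take (and return) two tensors?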
Could anyone help me, please?