# Unexpected Behavior with Weight Sharing between nn.Linear and nn.Embedding in a GPT-2 Implementation
## Problem Description
I've encountered unexpected behavior when sharing weights between `nn.Linear` and `nn.Embedding` layers in my GPT-2 implementation: the order of the weight-sharing assignment affects the model's sampling behavior, even though the model is only randomly initialized.
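To make the two variants concrete outside the full model, here is a minimal standalone sketch of the two tying directions (toy sizes, illustration only; not part of my actual model):

```python
import torch.nn as nn

d_model, vocab_size = 8, 11  # toy sizes, for illustration only

# Version A: point the linear head's weight at the embedding's parameter
emb_a = nn.Embedding(vocab_size, d_model)
head_a = nn.Linear(d_model, vocab_size, bias=False)
head_a.weight = emb_a.weight
print(head_a.weight is emb_a.weight)  # True: both modules now hold the same Parameter

# Version B: point the embedding's weight at the linear head's parameter
emb_b = nn.Embedding(vocab_size, d_model)
head_b = nn.Linear(d_model, vocab_size, bias=False)
emb_b.weight = head_b.weight
print(emb_b.weight is head_b.weight)  # True: shared here as well
```

In both directions the two modules end up sharing a single `Parameter`; the difference is which module's freshly initialized tensor survives the assignment.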
## Code

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Literal

# GPTConfig (definition omitted here) is a small config object exposing
# d_model, num_heads, hidden_dim, activation, vocab_size, max_seq_len and num_layers.

class MLP(nn.Module):
    def __init__(self, d_model: int, activation: Literal['gelu', 'relu'], hidden_dim, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.hidden_dim = hidden_dim
        self.in_layer = nn.LazyLinear(hidden_dim, **kwargs)
        self.activation = nn.GELU(approximate='tanh') if activation == "gelu" else nn.ReLU()
        self.out_layer = nn.LazyLinear(d_model, **kwargs)

    def forward(self, x: torch.Tensor):
        """
        :param x: Tensor, shape [batch_size, seq_len, d_model]
        :return: Tensor, shape [batch_size, seq_len, d_model]
        """
        x = self.in_layer(x)
        x = self.out_layer(self.activation(x))
        return x

class MultiHeadedAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, **kwargs):
        super().__init__(**kwargs)
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        self.qkv_projection = nn.LazyLinear(3 * d_model, **kwargs)
        self.out_projection = nn.LazyLinear(d_model, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, E = x.size()
        query, key, value = self.qkv_projection(x).chunk(3, dim=-1)
        query = query.view(B, S, self.num_heads, E // self.num_heads).transpose(1, 2)
        key = key.view(B, S, self.num_heads, E // self.num_heads).transpose(1, 2)
        value = value.view(B, S, self.num_heads, E // self.num_heads).transpose(1, 2)
        out = F.scaled_dot_product_attention(query, key, value, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(B, S, E)
        out = self.out_projection(out)
        return out

class Block(nn.Module):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.d_model = config.d_model
        self.num_heads = config.num_heads
        self.activation = config.activation
        self.hidden_dim = config.hidden_dim
        self.layer_norm1 = nn.LayerNorm(self.d_model, **kwargs)
        self.attention = MultiHeadedAttention(self.d_model, self.num_heads, **kwargs)
        self.layer_norm2 = nn.LayerNorm(self.d_model, **kwargs)
        self.mlp = MLP(self.d_model, self.activation, self.hidden_dim, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: torch.Tensor [batch_size, seq_len, d_model]
        :return: torch.Tensor [batch_size, seq_len, d_model]
        """
        x = x + self.attention(self.layer_norm1(x))
        x = x + self.mlp(self.layer_norm2(x))
        return x

class GPT2(nn.Module):
    def __init__(self, config: GPTConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.model = nn.ModuleDict(dict(
            tok_embedding=nn.Embedding(config.vocab_size, config.d_model, **kwargs),
            pos_embedding=nn.Embedding(config.max_seq_len, config.d_model, **kwargs),
            blocks=nn.ModuleList([Block(config, **kwargs) for _ in range(config.num_layers)]),
            final_norm=nn.LayerNorm(config.d_model, **kwargs),
        ))
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, **kwargs, bias=False)  # no bias, as in GPT-2

        # Only one of the two assignments below is enabled at a time.
        # Version A (produces degenerate outputs):
        self.lm_head.weight = self.model.tok_embedding.weight
        # Version B (works correctly):
        self.model.tok_embedding.weight = self.lm_head.weight

    def forward(self, idx: torch.Tensor, targets: torch.Tensor = None):
        """
        :param idx: torch.Tensor [batch_size, seq_len]
        :param targets: torch.Tensor [batch_size, seq_len]
        :return: (logits [batch_size, seq_len, vocab_size], loss)
        """
        B, T = idx.size()
        pos = torch.arange(0, T, device=idx.device, dtype=torch.long)
        x = self.model.pos_embedding(pos) + self.model.tok_embedding(idx)
        for block in self.model.blocks:
            x = block(x)
        x = self.model.final_norm(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss
```
## Inference Code and Observed Behavior
```python
import tiktoken

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = GPT2(GPTConfig())
model.eval()
model.to(device)
max_length = 30
seq_to_sample = 5
tokenizer = tiktoken.get_encoding('gpt2')
tokens = tokenizer.encode("Hello, I'm a language model,")
tokens = torch.tensor(tokens, dtype=torch.long, device=device)
tokens = tokens.unsqueeze(0).repeat(seq_to_sample, 1)
outputs = model.sampling_loop(tokens, max_length)
for i in range(seq_to_sample):
    out_tokens = outputs[i, :max_length].tolist()
    decoded = tokenizer.decode(out_tokens)
    print(f">{decoded}")
```
### Output with Version A (buggy)

```
>Hello, I'm a language model,,,,,,,,,,,,,,,,,,,,,,,
>Hello, I'm a language model,,,,,,,,,,,,,,,,,,,,,,,
>Hello, I'm a language model,,,,,,,,,,,,,,,,,,,,,,,
>Hello, I'm a language model,,,,,,,,,,,,,,,,,,,,,,,
>Hello, I'm a language model,,,,,,,,,,,,,,,,,,,,,,,
```
With Version A (`self.lm_head.weight = self.model.tok_embedding.weight`), the model produces degenerate output, repeatedly sampling the same token (a comma). With Version B (`self.model.tok_embedding.weight = self.lm_head.weight`), the model produces more diverse and sensible continuations.
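For what it's worth, one way to inspect what each assignment direction leaves in the shared parameter is to check it right after construction (diagnostic sketch only; output not shown here):

```python
# Diagnostic sketch: after building the model with one of the two assignments
# enabled, confirm the tie took effect and look at the scale of the shared matrix.
# (Illustration only; not part of the original run.)
model = GPT2(GPTConfig())
print(model.lm_head.weight is model.model.tok_embedding.weight)  # True if tied
print(model.lm_head.weight.std().item())  # effective init scale of the shared weight
```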
Why does the order of weight sharing assignment matter in this case, even with random initialization?