Unexpected Behavior with Weight Sharing between nn.Linear and nn.Embedding in GPT-2 Implementation

## Problem Description

I’ve encountered unexpected behavior when sharing weights between nn.Linear and nn.Embedding layers in my GPT-2 implementation. The direction of the weight-tying assignment affects the model’s sampling behavior, even with randomly initialized, untrained weights.

## Code

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Literal


class MLP(nn.Module):
    def __init__(self, d_model: int, activation: Literal['gelu', 'relu'], hidden_dim, **kwargs):
        super().__init__(**kwargs)

        self.d_model = d_model
        self.hidden_dim = hidden_dim

        self.in_layer = nn.LazyLinear(hidden_dim, **kwargs)
        self.activation = nn.GELU(approximate='tanh') if activation == "gelu" else nn.ReLU()
        self.out_layer = nn.LazyLinear(d_model, **kwargs)

    def forward(self, x: torch.Tensor):
        """
        Args:
        :param x: Tensor, shape [batch_size, seq_len, d_model]
        :return: Tensor, shape [batch_size, seq_len, d_model]
        """
        x = self.in_layer(x)
        x = self.out_layer(self.activation(x))
        return x


class MultiHeadedAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, **kwargs):
        super().__init__(**kwargs)
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads

        self.qkv_projection = nn.LazyLinear(3 * d_model, **kwargs)
        self.out_projection = nn.LazyLinear(d_model, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, E = x.size()

        query, key, value = self.qkv_projection(x).chunk(3, dim=-1)
        query = query.view(B, S, self.num_heads, E // self.num_heads).transpose(1, 2)
        key = key.view(B, S, self.num_heads, E // self.num_heads).transpose(1, 2)
        value = value.view(B, S, self.num_heads, E // self.num_heads).transpose(1, 2)

       
        # fused scaled-dot-product attention with a causal mask
        out = F.scaled_dot_product_attention(query, key, value, is_causal=True)
        out = out.transpose(1, 2).contiguous().view(B, S, E)
        out = self.out_projection(out)

        return out


class Block(nn.Module):
    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.d_model = config.d_model
        self.num_heads = config.num_heads
        self.activation = config.activation
        self.hidden_dim = config.hidden_dim

        self.layer_norm1 = nn.LayerNorm(self.d_model, **kwargs)
        self.attention = MultiHeadedAttention(self.d_model, self.num_heads, **kwargs)
        self.layer_norm2 = nn.LayerNorm(self.d_model, **kwargs)
        self.mlp = MLP(self.d_model, self.activation, self.hidden_dim, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: torch.Tensor [batch_size, seq_len, d_model]
        :return: torch.Tensor [batch_size, seq_len, d_model]
        """
        x = x + self.attention(self.layer_norm1(x))
        x = x + self.mlp(self.layer_norm2(x))
        return x


class GPT2(nn.Module):
    def __init__(self, config: GPTConfig, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.model = nn.ModuleDict(dict(
            tok_embedding=nn.Embedding(config.vocab_size, config.d_model, **kwargs),
            pos_embedding=nn.Embedding(config.max_seq_len, config.d_model, **kwargs),
            blocks=nn.ModuleList([Block(config, **kwargs) for _ in range(config.num_layers)]),
            final_norm=nn.LayerNorm(config.d_model, **kwargs),
        ))
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False, **kwargs)  # no bias, as in GPT-2

        # Weight tying: exactly one of the two assignments below is used at a time.
        # Version A (produces degenerate outputs):
        self.lm_head.weight = self.model.tok_embedding.weight

        # Version B (works correctly):
        # self.model.tok_embedding.weight = self.lm_head.weight

    def forward(self, idx: torch.Tensor, targets: torch.Tensor = None):
        """
        :param idx: torch.Tensor [batch_size, seq_len]
        :param targets torch.Tensor [batch_size, seq_len]
        :return: torch.Tensor [batch_size, seq_len, d_model]
        """
        B, T = idx.size()
        pos = torch.arange(0, T, device=idx.device, dtype=torch.long)
        x = self.model.pos_embedding(pos) + self.model.tok_embedding(idx)
        for block in self.model.blocks:
            x = block(x)

        x = self.model.final_norm(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

```
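
GPTConfig is not shown above; a minimal stand-in matching the fields the model code uses would look like this (the default values here are GPT-2-small-style placeholders, not necessarily the exact ones I used):

```python
from dataclasses import dataclass
from typing import Literal

# Minimal stand-in for the config referenced above; defaults are placeholders.
@dataclass
class GPTConfig:
    vocab_size: int = 50257
    max_seq_len: int = 1024
    d_model: int = 768
    num_layers: int = 12
    num_heads: int = 12
    hidden_dim: int = 3072                        # 4 * d_model
    activation: Literal['gelu', 'relu'] = 'gelu'
```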

## Inference Code and Observed Behavior
```python
import tiktoken

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = GPT2(GPTConfig())
model.eval()
model.to(device)

max_length = 30
seq_to_sample = 5

tokenizer = tiktoken.get_encoding('gpt2')
tokens = tokenizer.encode("Hello, I'm a language model,")
tokens = torch.tensor(tokens, dtype=torch.long, device=device)
tokens = tokens.unsqueeze(0).repeat(seq_to_sample, 1)

outputs = model.sampling_loop(tokens, max_length)

for i in range(seq_to_sample):
    out_tokens = outputs[i, :max_length].tolist()
    decoded = tokenizer.decode(out_tokens)
    print(f">{decoded}")
```

## Output with Version A (buggy)

```
>Hello, I'm a language model,,,,,,,,,,,,,,,,,,,,,,,
>Hello, I'm a language model,,,,,,,,,,,,,,,,,,,,,,,
>Hello, I'm a language model,,,,,,,,,,,,,,,,,,,,,,,
>Hello, I'm a language model,,,,,,,,,,,,,,,,,,,,,,,
>Hello, I'm a language model,,,,,,,,,,,,,,,,,,,,,,,
```

The model with Version A (where `self.lm_head.weight = self.model.tok_embedding.weight`) produces degenerate outputs, repeatedly sampling the same token (comma). However, when using Version B (`self.model.tok_embedding.weight = self.lm_head.weight`), the model produces more diverse and sensible continuations.
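
`sampling_loop` is omitted above for brevity; a minimal sketch of such a loop, assuming plain autoregressive top-k multinomial sampling (written here as a free function rather than a method), would be:

```python
import torch
import torch.nn.functional as F

# Sketch of an autoregressive sampling loop (assumes top-k multinomial sampling).
@torch.no_grad()
def sampling_loop(model, tokens, max_length, top_k=50):
    while tokens.size(1) < max_length:
        logits, _ = model(tokens)              # [B, T, vocab_size]
        logits = logits[:, -1, :]              # keep only the last position
        probs = F.softmax(logits, dim=-1)
        topk_probs, topk_idx = torch.topk(probs, top_k, dim=-1)
        sampled = torch.multinomial(topk_probs, num_samples=1)   # index into the top-k set
        next_token = topk_idx.gather(-1, sampled)                # map back to vocab ids
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens
```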

Why does the direction of the weight-tying assignment matter here, even with randomly initialized, untrained weights?
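
One concrete difference between the two versions is which layer's default initialization the shared parameter ends up keeping (nn.Embedding initializes from N(0, 1), while nn.Linear uses a much smaller uniform range); this is easy to check on the two layers in isolation:

```python
import torch.nn as nn

emb = nn.Embedding(50257, 768)            # default init: N(0, 1)
head = nn.Linear(768, 50257, bias=False)  # default init: small uniform range

# Version A: the head now uses the embedding's N(0, 1) weights
head.weight = emb.weight
print(head.weight.std())   # roughly 1.0

emb2 = nn.Embedding(50257, 768)
head2 = nn.Linear(768, 50257, bias=False)

# Version B: the embedding now uses the linear layer's small uniform weights
emb2.weight = head2.weight
print(emb2.weight.std())   # roughly 0.02
```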

Your model is missing the `forward` function and should thus directly fail with:

```
    raise NotImplementedError(

NotImplementedError: Module [GPT2] is missing the required "forward" function
```

since `self(x)` will call into `self._call_impl(*args, **kwargs)` and `forward_call(*args, **kwargs)`.
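
A minimal way to reproduce that error with any module that has no forward:

```python
import torch
import torch.nn as nn

class NoForward(nn.Module):
    pass

m = NoForward()
m(torch.zeros(1))
# NotImplementedError: Module [NoForward] is missing the required "forward" function
```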


I omitted it because it's long. It's there now.