Replacing the Hugging Face LlamaDecoderLayer Class with a New LongNet-Based Layer

Hi, I hope you are doing great this weekend.

I would like to ask a technical question, please.

I am working on the CodeLlama model, which is a decoder-only Transformer with the architecture shown below.

My main task is to replace the decoder-only layers, which use masked self-attention and a KV cache, with my own encoder-only layers that use the dilated attention mechanism from LongNet.

Here is the code I start from:

from transformers import AutoTokenizer, AutoModelForCausalLM
import transformers
import torch

model_id = "codellama/CodeLlama-7b-hf"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16
).to("cpu")

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32016, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32016, bias=False)
)

I plan to replace the LlamaDecoderLayer block with an encoder-only equivalent. Here is the original decoder-only block used in CodeLlama:

(0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()

I want to replace it with my own layer by inheriting from the Hugging Face base classes. Here is the process I followed.

Step 1: Inherit from LlamaConfig to add the new configuration parameters needed by my encoder model, which uses dilated multi-head attention.

from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel, LlamaForCausalLM
class CondensedLlamaConfig(LlamaConfig):
    def __init__(
        self,
        dilation_rates=None,
        segment_lengths=None,
        is_causal=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.dilation_rates = dilation_rates
        self.segment_lengths = segment_lengths
        self.is_causal = is_causal

    # Override the `to_dict` method to include the new parameters
    def to_dict(self):
        base_dict = super().to_dict()
        config_dict = {
            "dilation_rates": self.dilation_rates,
            "segment_lengths": self.segment_lengths,
            "is_causal": self.is_causal
        }
        base_dict.update(config_dict)
        return base_dict
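
For reference, the output below can presumably be reproduced by instantiating the config with only the new parameters and letting every other field fall back to the LlamaConfig defaults. A minimal sketch (the values simply mirror the printed config):

# Sketch: only the new parameters are passed; all other fields use LlamaConfig defaults.
config = CondensedLlamaConfig(
    dilation_rates=[2048, 4096, 8192, 16384, 32768],
    segment_lengths=[1, 2, 4, 6, 12],
    is_causal=False,
)
print(config)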

Output:

CondensedLlamaConfig {
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "dilation_rates": [
    2048,
    4096,
    8192,
    16384,
    32768
  ],
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "is_causal": false,
  "max_position_embeddings": 2048,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "segment_lengths": [
    1,
    2,
    4,
    6,
    12
  ],
  "tie_word_embeddings": false,
  "transformers_version": "4.38.2",
  "use_cache": true,
  "vocab_size": 32000
}

Step 2: The only part I want to replace is self_attn. My own multi-head dilated attention follows the LongNet mechanism, as in the code below.

In the dilated attention, FlashAttention-2 is optional and depends on whether the GPU architecture supports it (e.g., an A100 does, a T4 does not).
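
For example, FlashAttention could be made optional by passing an explicit op only on GPUs that support it. A rough sketch (pick_attention_op is just a hypothetical helper, not part of my code):

import torch
import xformers.ops as xops

def pick_attention_op():
    # FlashAttention needs Ampere or newer (e.g., A100, sm_80+);
    # on a T4 (sm_75) return None so xformers dispatches a supported kernel.
    if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8:
        return xops.MemoryEfficientAttentionFlashAttentionOp
    return None

The returned value can then be passed as the op argument of DilatedAttention / MultiheadDilatedAttention below.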

from typing import Optional, Sequence, Tuple, Union

import torch
from torch import Tensor, nn
from einops import rearrange
import xformers.ops as xops

class DilatedAttention(nn.Module):
    """
    DilatedAttention module implements dilated, scaled dot product attention with softmax.

    Args:
        segment_lengths (Sequence[int]): Lengths of segments for attention.
        dilation_rates (Sequence[int]): Dilation rates for attention.
        softmax_scale (Optional[float]): Temperature for softmax attention. Default is None.
        attention_dropout (float): Dropout rate for attention. Default is 0.0.
        op (Optional[xops.AttentionOp]): Attention operation. Default is None.
    """

    def __init__(
        self,
        segment_lengths: Sequence[int],
        dilation_rates: Sequence[int],
        softmax_scale: Optional[float] = None,
        attention_dropout: float = 0.0,
        op: Optional[xops.AttentionOp] = None,
    ):
        super().__init__()
        if len(segment_lengths) != len(dilation_rates):
            raise ValueError("segment_lengths and dilation_rates must have the same length")

        self.segment_lengths = segment_lengths
        self.dilation_rates = dilation_rates
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout
        self.op = op

    def forward(self, query: Tensor, key: Tensor, value: Tensor, is_causal: bool = False) -> Tensor:
        """
        Forward pass of the DilatedAttention module.

        Args:
            query (Tensor): Query tensor.
            key (Tensor): Key tensor.
            value (Tensor): Value tensor.
            is_causal (bool): Flag indicating if the attention is causal. Default is False.

        Returns:
            Tensor: Output tensor.
        """
        b, n, h, d = query.shape
        out = torch.zeros_like(query)
        num_groups = len(self.dilation_rates)
        # Split the heads evenly across the (segment length, dilation rate) groups.
        group_sizes = [h // num_groups] * num_groups
        for i, (g, r, s) in enumerate(zip(group_sizes, self.dilation_rates, self.segment_lengths)):
            # Fold segments of length s into the batch dimension, then keep every
            # r-th position (dilation) and the heads assigned to this group.
            offset = i % r
            hmin, hmax = i * g, (i + 1) * g
            q = rearrange(query, "b (n s) h d -> (b n) s h d", s=s)[:, offset::r, hmin:hmax, :].contiguous()
            k = rearrange(key, "b (n s) h d -> (b n) s h d", s=s)[:, offset::r, hmin:hmax, :].contiguous()
            v = rearrange(value, "b (n s) h d -> (b n) s h d", s=s)[:, offset::r, hmin:hmax, :].contiguous()
            attn_bias = xops.LowerTriangularMask() if is_causal else None
            x = xops.memory_efficient_attention(
                query=q, key=k, value=v, op=self.op, attn_bias=attn_bias
            )
            # Scatter the sparse attention outputs back to their original positions.
            out_view = rearrange(out, "b (n s) h d -> (b n) s h d", s=s)
            out_view[:, offset::r, hmin:hmax, :] += x
            out = rearrange(out_view, "(b n) s h d -> b (n s) h d", b=b)
        # Average the group contributions (the LongNet paper weights them by their
        # softmax denominators; a plain average keeps this version simple).
        return out / num_groups

Here is the multi-head dilated attention:

class MultiheadDilatedAttention(nn.Module):
    """
    MultiheadDilatedAttention module implements a multi-head dilated attention mechanism.

    Args:
        embed_dim (int): The dimension of the input embeddings.
        num_heads (int): Number of attention heads.
        dilation_rates (Sequence[int]): Dilation rates for attention.
        segment_lengths (Sequence[int]): Lengths of segments for attention.
        dropout (float): Dropout rate for attention. Default is 0.0.
        bias (bool): If True, enables bias in linear projections. Default is True.
        layer_norm (bool): If True, applies layer normalization. Default is True.
        layer_norm_eps (float): Epsilon value for layer normalization. Default is 1e-5.
        gamma_init (float): Initialization value for gain in linear projections. Default is 1.0.
        device (Optional[Union[torch.device, str]]): Device for parameters. Default is None.
        dtype (Optional[torch.dtype]): Data type for parameters. Default is None.
        op (Optional[xops.AttentionOp]): Attention operation. Default is None.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dilation_rates: Sequence[int],
        segment_lengths: Sequence[int],
        dropout: float = 0.0,
        bias: bool = False,
        layer_norm: bool = True,
        layer_norm_eps: float = 1e-5,
        gamma_init: float = 1.0,
        device: Optional[Union[torch.device, str]] = None,
        dtype: Optional[torch.dtype] = None,
        op: Optional[xops.AttentionOp] = None,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.layer_norm = layer_norm
        self.gamma_init = gamma_init

        if not embed_dim % self.num_heads == 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        num_dilations = len(dilation_rates)
        num_segments = len(segment_lengths)
        if num_dilations != num_segments:
            raise ValueError(
                f"len(dilation_rates) ({num_dilations}) must be equal to "
                f"len(segment_lengths) ({num_segments})"
            )

        head_dim = embed_dim // num_heads
        if not head_dim % 8 == 0:
            raise ValueError(
                f"head_dim (embed_dim / num_heads = {head_dim}) must be divisible by 8"
            )
        if not head_dim <= 128:
            raise ValueError(
                f"head_dim (embed_dim / num_heads = {head_dim}) must be <= 128"
            )

        self.q_proj = nn.Linear(
            embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
        )
        self.k_proj = nn.Linear(
            embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
        )
        self.v_proj = nn.Linear(
            embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
        )
        self.attention = DilatedAttention(
            segment_lengths=segment_lengths,
            dilation_rates=dilation_rates,
            attention_dropout=dropout,
            op=op,
        )
        self.norm: Optional[nn.LayerNorm] = None
        if layer_norm:
            self.norm = nn.LayerNorm(
                embed_dim, eps=layer_norm_eps, device=device, dtype=dtype
            )
        self.o_proj = nn.Linear(
            embed_dim, embed_dim, bias=bias, device=device, dtype=dtype
        )

        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.xavier_normal_(self.q_proj.weight)
        if self.q_proj.bias is not None:
            nn.init.constant_(self.q_proj.bias, 0)
        nn.init.xavier_normal_(self.k_proj.weight)
        if self.k_proj.bias is not None:
            nn.init.constant_(self.k_proj.bias, 0)

        nn.init.xavier_normal_(self.v_proj.weight, gain=self.gamma_init)
        if self.v_proj.bias is not None:
            nn.init.constant_(self.v_proj.bias, 0)
        nn.init.xavier_normal_(self.o_proj.weight, gain=self.gamma_init)
        if self.o_proj.bias is not None:
            nn.init.constant_(self.o_proj.bias, 0)

    def forward(
        self, query: Tensor, key: Tensor, value: Tensor, is_causal: bool = False
    ) -> Tuple[Tensor, None]:
        """
        Forward pass of the MultiheadDilatedAttention module.

        Args:
            query (Tensor): Query tensor.
            key (Tensor): Key tensor.
            value (Tensor): Value tensor.
            is_causal (bool): Flag indicating if the attention is causal. Default is False.

        Returns:
            Tuple[Tensor, None]: Output tensor and None.
        """
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        q = rearrange(q, "b n (h d) -> b n h d", h=self.num_heads)
        k = rearrange(k, "b n (h d) -> b n h d", h=self.num_heads)
        v = rearrange(v, "b n (h d) -> b n h d", h=self.num_heads)
        x = self.attention(q, k, v, is_causal=is_causal)
        x = rearrange(x, "b n h d -> b n (h d)")

        if self.layer_norm:
            assert self.norm is not None
            x = self.norm(x)
        x = self.o_proj(x)

        return x, None
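
A minimal smoke test of MultiheadDilatedAttention with toy dimensions (not the 7B config), assuming a CUDA GPU with xformers installed:

# Toy-sized sanity check: seq_len must be divisible by every segment length.
torch.manual_seed(0)
mhda = MultiheadDilatedAttention(
    embed_dim=256,
    num_heads=8,
    dilation_rates=[1, 2],
    segment_lengths=[16, 32],
    device="cuda",
)
x = torch.randn(2, 64, 256, device="cuda")  # (batch, seq_len, embed_dim)
out, _ = mhda(x, x, x, is_causal=False)
print(out.shape)  # torch.Size([2, 64, 256])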

Step 3: Replace the layer itself by inheriting from the Hugging Face base classes:

import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaDecoderLayer, LlamaModel

class CondensedLlamaAttention(LlamaAttention):
    def __init__(self, config: CondensedLlamaConfig, layer_idx=None):
        super().__init__(config, layer_idx=layer_idx)

        self.LongNetAttention = MultiheadDilatedAttention(
            config.hidden_size,
            config.num_attention_heads,
            config.dilation_rates,
            config.segment_lengths
        )
        self.is_causal = config.is_causal


    def forward(self, input, is_causal=None):
        if is_causal is None:
            is_causal = self.is_causal
        x, _ = self.LongNetAttention(input, input, input, is_causal=is_causal)
        return x


class CondensedLlamaDecoderLayer(LlamaDecoderLayer):

    def __init__(self, config: CondensedLlamaConfig, layer_idx=None):
        super().__init__(config, layer_idx=layer_idx)  # forward layer_idx to the parent constructor
        # Replace self_attn with your new attention module
        self.self_attn = MultiheadDilatedAttention(
            config.hidden_size,
            config.num_attention_heads,
            config.dilation_rates,
            config.segment_lengths
        )
        self.is_causal = config.is_causal


    def forward(self, input, is_causal=None):
        if is_causal is None:
            is_causal = self.is_causal
        # Use the dilated attention assigned to self_attn above.
        # Note: this override bypasses the parent layer's MLP and RMSNorm sub-modules.
        x, _ = self.self_attn(input, input, input, is_causal=is_causal)
        return x


class CondensedLlamaModel(LlamaModel):
    def __init__(self, config: CondensedLlamaConfig):
        super().__init__(config)

        self.layers = nn.ModuleList(
            [CondensedLlamaDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )
        # Initialize weights and apply final processing
        self.post_init()

Note: as long as is_causal is False, the attention is not masked, so the model learns fully bidirectional representations: every token can attend to every other token in the dot-product similarity. That is exactly what I want for an encoder-only model that produces token embeddings, rather than the masked attention used in a decoder-only model.
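
As a tiny illustration of that point with plain scaled dot-product attention on toy tensors (assuming PyTorch >= 2.0):

import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 1, 4, 8)  # (batch, heads, seq_len, head_dim)
bidirectional = F.scaled_dot_product_attention(q, k, v, is_causal=False)
causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
# With masking, the first token can only attend to itself; without masking it
# already mixes in information from later tokens, so the two outputs differ.
print(torch.allclose(bidirectional[..., 0, :], causal[..., 0, :]))  # False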

Step 4: Reconstruct the model using the adjusted config class as follows.
Note: I set config.num_hidden_layers = 2 only as a showcase; the original parameter is num_hidden_layers=32.

config.num_hidden_layers = 2
model_1 = CondensedLlamaModel(config)
model_1

Note: I did not keep the rotary embedding, because the attention I use is the (linear-complexity) dilated attention.

Question 1: Please correct me if I actually need to keep the rotary embedding in my encoder-only model.
Output:

CondensedLlamaModel(
  (embed_tokens): Embedding(32000, 4096)
  (layers): ModuleList(
    (0-1): 2 x CondensedLlamaDecoderLayer(
      (self_attn): MultiheadDilatedAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (attention): DilatedAttention()
        (norm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)

Final step: transfer the weights of the ["q_proj", "k_proj", "v_proj", "o_proj"] layers from the decoder-only model to the encoder-only model.

Here is a comparison of the new encoder-only block with the original decoder-only block.

Decoder-only block used in CodeLlama:

(0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()

Encoder-only block after my adjustments:

CondensedLlamaDecoderLayer(
      (self_attn): MultiheadDilatedAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (attention): DilatedAttention()
        (norm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
      )

Both blocks have the same linear layers: ["q_proj", "k_proj", "v_proj", "o_proj"].
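
To double-check this before transferring anything, the projection shapes can be compared directly (a quick sketch; model.model is the base LlamaModel inside the CausalLM wrapper):

src = model.model.state_dict()
dst = model_1.state_dict()
for pattern in ["q_proj", "k_proj", "v_proj", "o_proj"]:
    key = f"layers.0.self_attn.{pattern}.weight"
    print(key, src[key].shape == dst[key].shape)  # expect True for all four projections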

Here is the code I wrote to transfer the weights:

import torch
module_patterns_to_transfer = ["q_proj", "k_proj", "v_proj", "o_proj"]
def transfer_weights(original_model, custom_model, module_patterns_to_transfer):
    original_dict = original_model.state_dict()
    custom_dict = custom_model.state_dict()

    # Filter and transfer weights for specified layers
    for key in custom_dict.keys():
        for pattern in module_patterns_to_transfer:
            if pattern in key:
                if key in original_dict:
                    # Transfer weights
                    with torch.no_grad():
                        custom_dict[key].copy_(original_dict[key])

    # Load the updated state dictionary to the model
    custom_model.load_state_dict(custom_dict)

# Transfer weights from the base LlamaModel inside the original CausalLM wrapper,
# whose state-dict keys (e.g. "layers.0.self_attn.q_proj.weight") match model_1's
transfer_weights(model.model, model_1, module_patterns_to_transfer)

# Inspect the transferred weights in the custom model
for key, parameter in model_1.state_dict().items():
    print(key)
    print(parameter.size())
    print(parameter)

Output:

embed_tokens.weight
torch.Size([32000, 4096])
tensor([[-0.0052,  0.0353,  0.0152,  ..., -0.0285,  0.0035,  0.0149],
        [ 0.0018, -0.0054,  0.0005,  ...,  0.0048,  0.0319,  0.0018],
        [ 0.0238,  0.0032, -0.0004,  ...,  0.0171, -0.0069, -0.0232],
        ...,
        [ 0.0084, -0.0174,  0.0109,  ...,  0.0083,  0.0139, -0.0389],
        [-0.0012, -0.0267,  0.0011,  ...,  0.0287,  0.0102, -0.0176],
        [ 0.0023,  0.0041,  0.0118,  ...,  0.0253,  0.0198, -0.0259]])
layers.0.self_attn.q_proj.weight
torch.Size([4096, 4096])
tensor([[ 1.8845e-03,  7.0190e-04, -5.3406e-03,  ...,  5.7373e-03,
          5.5847e-03,  2.2650e-05],
        [ 7.2937e-03, -5.8594e-03,  4.7607e-03,  ..., -7.3242e-03,
         -7.1106e-03, -9.9945e-04],
        [-1.4282e-02,  6.2561e-03,  8.5831e-04,  ...,  6.0120e-03,
          9.8267e-03,  1.0986e-03],
        ...,
        [ 1.9531e-02, -4.6692e-03,  1.1841e-02,  ...,  1.6602e-02,
         -1.3550e-02,  2.7847e-04],
        [-1.2512e-02,  8.5449e-04, -6.8665e-03,  ..., -2.1362e-02,
         -2.0142e-02, -6.6528e-03],
        [ 5.8289e-03,  3.7231e-03,  5.7068e-03,  ...,  9.5215e-03,
          7.0496e-03, -4.0588e-03]])
layers.0.self_attn.k_proj.weight
torch.Size([4096, 4096])
tensor([[ 1.4404e-02,  1.4221e-02, -2.3804e-03,  ...,  4.3640e-03,
         -1.1475e-02, -9.7046e-03],
        [-3.0396e-02, -3.4485e-03,  4.4250e-03,  ..., -8.4229e-03,
          1.2390e-02,  1.2512e-02],
        [ 1.0071e-03, -1.5747e-02,  1.7090e-03,  ...,  9.8877e-03,
          8.0109e-04, -8.6670e-03],
        ...,
        [ 5.7373e-03,  4.3030e-03,  9.9945e-04,  ..., -2.8839e-03,
          4.0894e-03,  5.0964e-03],
        [-3.6316e-03,  2.1057e-03, -5.7678e-03,  ...,  4.1723e-07,
          4.6082e-03, -1.1108e-02],
        [ 2.7313e-03,  3.7231e-03,  1.5488e-03,  ...,  2.7313e-03,
         -9.8877e-03,  6.1035e-03]])
layers.0.self_attn.v_proj.weight
....
....
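
As an optional sanity check (a sketch; the key names assume the two-layer showcase model above), the copied projections can be compared against the source model:

src = model.model.state_dict()  # base LlamaModel inside the CausalLM wrapper
dst = model_1.state_dict()
for name in ["layers.0.self_attn.q_proj.weight", "layers.1.self_attn.o_proj.weight"]:
    print(name, torch.allclose(src[name].float(), dst[name].float()))  # expect True after the transfer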

Please correct me if I have misunderstood anything. I received negative feedback from my CEO on this process, and I told him I was correct that transforming CodeLlama into an encoder-only model to learn the embeddings is the right approach.

Thank you so much in advance.