Custom model with 2 GPT2 models from Hugging Face

I want to make a custom model that combines 2 GPT2 models, with the following modification at every attention layer: the self-attention outputs of adjacent layers (the layer at the same depth in each model) get cross-attended, and the cross-attention output becomes the input to the next layer of one of the GPT2 models, let's say the 2nd one for simplicity, while the 1st model's next layer receives the unmodified output of its previous layer. This has to happen at every layer.
I already wrote some code, but I don't think it is working as intended.

import torch
from torch import nn
from transformers import GPT2LMHeadModel, GPT2Config
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

class CrossAttentionGPT2Model(nn.Module):
    def __init__(self, config1, config2):
        super().__init__()
        self.config1 = config1
        self.config2 = config2

        # Load each pretrained model once and reuse its sub-modules
        model1 = GPT2LMHeadModel.from_pretrained('/kaggle/working/models/model1')
        model2 = GPT2LMHeadModel.from_pretrained('/kaggle/input/gp2-recommender/GPT2_recommender')

        # Token/position embeddings and embedding dropout from both models
        self.wte1 = model1.transformer.wte
        self.wpe1 = model1.transformer.wpe
        self.wte2 = model2.transformer.wte
        self.wpe2 = model2.transformer.wpe
        self.drop1 = model1.transformer.drop
        self.drop2 = model2.transformer.drop

        # Transformer blocks of both models, plus one cross-attention block per layer
        self.self_attn_1 = nn.ModuleList([model1.transformer.h[n] for n in range(config1.n_layer)])
        self.self_attn_2 = nn.ModuleList([model2.transformer.h[n] for n in range(config2.n_layer)])
        self.cross_attn = nn.ModuleList([TransformerBlock(config2.n_embd, config2.n_head) for _ in range(config2.n_layer)])

        # Final layer norm and LM head come from the 2nd model, which produces the logits
        self.ln_f = model2.transformer.ln_f
        self.lm_head = model2.lm_head

        self.loss_fn = CrossEntropyLoss(ignore_index=-100)
        self.dtype = torch.float16  # dtype used to build the additive attention masks

    def forward(self, input_ids1, input_ids2, attention_mask1=None, attention_mask2=None, labels=None):
        # Flatten the inputs to (batch_size, seq_len)
        input_shape1 = input_ids1.size()
        input_ids1 = input_ids1.view(-1, input_shape1[-1])
        batch_size1 = input_ids1.shape[0]
        input_shape2 = input_ids2.size()
        input_ids2 = input_ids2.view(-1, input_shape2[-1])
        batch_size2 = input_ids2.shape[0]
        device = input_ids2.device

        # Token embeddings plus positional embeddings for both streams
        position_ids1 = torch.arange(0, input_shape1[-1], dtype=torch.long, device=device).unsqueeze(0)
        position_ids2 = torch.arange(0, input_shape2[-1], dtype=torch.long, device=device).unsqueeze(0)

        input_embeds1 = self.wte1(input_ids1)
        input_embeds2 = self.wte2(input_ids2)
        position_embeds1 = self.wpe1(position_ids1)
        position_embeds2 = self.wpe2(position_ids2)
        hidden_states1 = input_embeds1 + position_embeds1
        hidden_states2 = input_embeds2 + position_embeds2

        # Build additive attention masks (0 for visible tokens, a large negative value for padding);
        # default to all-ones masks when none are provided
        if attention_mask1 is None:
            attention_mask1 = torch.ones(input_ids1.shape, device=device)
        if attention_mask2 is None:
            attention_mask2 = torch.ones(input_ids2.shape, device=device)
        attention_mask1 = attention_mask1[:, None, None, :].to(dtype=self.dtype)
        attention_mask2 = attention_mask2[:, None, None, :].to(dtype=self.dtype)
        attention_mask1 = (1.0 - attention_mask1) * torch.finfo(self.dtype).min
        attention_mask2 = (1.0 - attention_mask2) * torch.finfo(self.dtype).min
    
    
        hidden_states1 = self.drop1(hidden_states1)
        hidden_states2 = self.drop2(hidden_states2)

        output_shape1 = (-1,) + input_shape1[1:] + (hidden_states1.size(-1),)
        output_shape2 = (-1,) + input_shape2[1:] + (hidden_states2.size(-1),)

        # Pair up the two models' blocks and the cross-attention blocks layer by layer
        # (the zip assumes config1.n_layer == config2.n_layer)
        for layer1, layer2, cross_attn_layer in zip(self.self_attn_1, self.self_attn_2, self.cross_attn):
            attention_mask1 = attention_mask1.to(hidden_states1.device)
            attention_mask2 = attention_mask2.to(hidden_states2.device)
            # Self-attention in each stream
            out1 = layer1(hidden_states1, attention_mask=attention_mask1)
            out2 = layer2(hidden_states2, attention_mask=attention_mask2)
            hidden_states1 = out1[0]  # stream 1 continues unmodified
            hidden_states2 = out2[0]
            # Cross-attend stream 2 (queries) over stream 1 (keys/values);
            # the result feeds the next layer of model 2 only
            hidden_states2 = cross_attn_layer(hidden_states2, hidden_states1)
        
    
        # Final layer norm and LM head of model 2 produce the logits
        hidden_states2 = self.ln_f(hidden_states2)
        hidden_states2 = hidden_states2.view(output_shape2)
        lm_logits = self.lm_head(hidden_states2)
    
    
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits)
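
For reference, the TransformerBlock used for the cross attention is not included in the snippet above. Below is a minimal sketch of what such a block could look like, matching the TransformerBlock(config2.n_embd, config2.n_head) constructor and the cross_attn_layer(hidden_states2, hidden_states1) call in the layer loop; the internals here are an assumption, not the actual class.

# Assumed cross-attention block: queries come from the stream being modified
# (model 2), keys/values from the other stream (model 1).
class TransformerBlock(nn.Module):
    def __init__(self, n_embd, n_head, dropout=0.1):
        super().__init__()
        self.ln_q = nn.LayerNorm(n_embd)
        self.ln_kv = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, dropout=dropout, batch_first=True)
        self.ln_mlp = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, query_states, key_value_states):
        # query_states: hidden states of model 2; key_value_states: hidden states of model 1
        q = self.ln_q(query_states)
        kv = self.ln_kv(key_value_states)
        attn_out, _ = self.attn(q, kv, kv, need_weights=False)
        hidden = query_states + attn_out                 # residual around cross-attention
        hidden = hidden + self.mlp(self.ln_mlp(hidden))  # residual around the MLP
        return hidden

Whatever the real implementation looks like, it has to be defined before CrossAttentionGPT2Model is instantiated, since __init__ builds one such block per layer.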
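
And a quick smoke test for the wiring. The checkpoint paths are the ones from the snippet; the batch size, sequence length, and the use of input_ids2 as labels are placeholder assumptions.

# Hypothetical smoke test; shapes and the labels choice are placeholders.
config1 = GPT2Config.from_pretrained('/kaggle/working/models/model1')
config2 = GPT2Config.from_pretrained('/kaggle/input/gp2-recommender/GPT2_recommender')
model = CrossAttentionGPT2Model(config1, config2)

batch_size, seq_len = 2, 16
input_ids1 = torch.randint(0, config1.vocab_size, (batch_size, seq_len))
input_ids2 = torch.randint(0, config2.vocab_size, (batch_size, seq_len))
attention_mask1 = torch.ones(batch_size, seq_len)
attention_mask2 = torch.ones(batch_size, seq_len)
labels = input_ids2.clone()  # placeholder: labels for the 2nd (logit-producing) stream

out = model(input_ids1, input_ids2, attention_mask1, attention_mask2, labels=labels)
print(out.loss, out.logits.shape)  # expected logits shape: (batch_size, seq_len, config2.vocab_size)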