I want to make a custom model. The custom model will combine two GPT2 models, but the following modification is needed at each attention layer.
At each layer, the self-attention outputs of the two models get cross-attended, and the cross-attention output goes into the input of the next layer of one of the GPT2 models, let's say the 2nd one for simplicity, while the 1st model's next layer gets the unmodified output of its previous layer. This has to happen at every layer.
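To make the wiring concrete, here is a rough sketch of one such layer, using placeholder PyTorch modules instead of the real GPT2 blocks (the module names and sizes below are only illustrative, not from my actual code):

import torch
from torch import nn

n_embd, n_head = 768, 12
block1 = nn.TransformerEncoderLayer(n_embd, n_head, batch_first=True)  # stand-in for a GPT2 block of model 1
block2 = nn.TransformerEncoderLayer(n_embd, n_head, batch_first=True)  # stand-in for a GPT2 block of model 2
cross = nn.MultiheadAttention(n_embd, n_head, batch_first=True)        # stand-in for the cross-attention

h1 = torch.randn(1, 16, n_embd)  # hidden states of model 1's stream
h2 = torch.randn(1, 16, n_embd)  # hidden states of model 2's stream

# one layer of the combined model:
h1 = block1(h1)                            # model 1's next layer gets this unmodified
s2 = block2(h2)                            # model 2's self-attention output at this layer
h2, _ = cross(query=s2, key=h1, value=h1)  # cross-attended output feeds model 2's next layer

So model 1's stream is never touched by the cross-attention; only model 2's stream is replaced before its next block.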
I already wrote some code, but I don't think it is working as needed:
import torch
from torch import nn
from transformers import GPT2LMHeadModel, GPT2Config
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
class CrossAttentionGPT2Model(nn.Module):
    def __init__(self, config1, config2):
        super().__init__()
        self.config1 = config1
        self.config2 = config2

        # Load each pretrained checkpoint once and reuse its sub-modules
        model1 = GPT2LMHeadModel.from_pretrained('/kaggle/working/models/model1')
        model2 = GPT2LMHeadModel.from_pretrained('/kaggle/input/gp2-recommender/GPT2_recommender')

        # Embeddings and dropout for each stream
        self.wte1 = model1.transformer.wte
        self.wpe1 = model1.transformer.wpe
        self.wte2 = model2.transformer.wte
        self.wpe2 = model2.transformer.wpe
        self.drop1 = model1.transformer.drop
        self.drop2 = model2.transformer.drop

        # Pretrained GPT-2 blocks for each stream
        self.self_attn_1 = nn.ModuleList([model1.transformer.h[n] for n in range(config1.n_layer)])
        self.self_attn_2 = nn.ModuleList([model2.transformer.h[n] for n in range(config2.n_layer)])

        # One cross-attention block per layer; its output feeds the next block of the 2nd model
        self.cross_attn = nn.ModuleList([TransformerBlock(config2.n_embd, config2.n_head) for _ in range(config2.n_layer)])

        # Final layer norm and LM head come from the 2nd model, which produces the logits
        self.ln_f = model2.transformer.ln_f
        self.lm_head = model2.lm_head

        self.loss_fn = CrossEntropyLoss(ignore_index=-100)
        self.dtype = torch.float16  # dtype used when building the additive attention masks
    def forward(self, input_ids1, input_ids2, attention_mask1=None, attention_mask2=None, labels=None):
        # Flatten any leading batch dimensions of both input streams
        input_shape1 = input_ids1.size()
        input_ids1 = input_ids1.view(-1, input_shape1[-1])
        input_shape2 = input_ids2.size()
        input_ids2 = input_ids2.view(-1, input_shape2[-1])
        device = input_ids2.device

        # Token + position embeddings for each stream
        position_ids1 = torch.arange(0, input_shape1[-1], dtype=torch.long, device=device).unsqueeze(0)
        position_ids2 = torch.arange(0, input_shape2[-1], dtype=torch.long, device=device).unsqueeze(0)
        hidden_states1 = self.wte1(input_ids1) + self.wpe1(position_ids1)
        hidden_states2 = self.wte2(input_ids2) + self.wpe2(position_ids2)

        # Turn padding masks into additive masks (0 where attended, large negative where masked)
        if attention_mask1 is not None:
            attention_mask1 = attention_mask1[:, None, None, :].to(dtype=self.dtype)
            attention_mask1 = (1.0 - attention_mask1) * torch.finfo(self.dtype).min
        if attention_mask2 is not None:
            attention_mask2 = attention_mask2[:, None, None, :].to(dtype=self.dtype)
            attention_mask2 = (1.0 - attention_mask2) * torch.finfo(self.dtype).min

        hidden_states1 = self.drop1(hidden_states1)
        hidden_states2 = self.drop2(hidden_states2)
        output_shape2 = (-1,) + input_shape2[1:] + (hidden_states2.size(-1),)

        # Walk both stacks in lockstep (assumes config1.n_layer == config2.n_layer)
        for layer1, layer2, cross_attn_layer in zip(self.self_attn_1, self.self_attn_2, self.cross_attn):
            out1 = layer1(hidden_states1, attention_mask=attention_mask1)
            out2 = layer2(hidden_states2, attention_mask=attention_mask2)
            hidden_states1 = out1[0]  # the 1st model's next layer sees its own unmodified output
            hidden_states2 = out2[0]
            # the 2nd model's next layer sees the cross-attention of the two self-attention outputs
            hidden_states2 = cross_attn_layer(hidden_states2, hidden_states1)

        # Final layer norm + LM head of the 2nd model produce the logits
        hidden_states2 = self.ln_f(hidden_states2)
        hidden_states2 = hidden_states2.view(output_shape2)
        logits = self.lm_head(hidden_states2)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithCrossAttentions(loss=loss, logits=logits)
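One thing to flag: TransformerBlock is never defined in the snippet, so the code won't run as posted. Below is only a guess at a minimal cross-attention block that matches the TransformerBlock(config2.n_embd, config2.n_head) constructor and the cross_attn_layer(hidden_states2, hidden_states1) call above (queries from model 2's stream, keys/values from model 1's); the actual class may look different:

class TransformerBlock(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln_q = nn.LayerNorm(n_embd)
        self.ln_kv = nn.LayerNorm(n_embd)
        self.attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
        self.ln_mlp = nn.LayerNorm(n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.GELU(),
            nn.Linear(4 * n_embd, n_embd),
        )

    def forward(self, query_states, key_value_states):
        # Cross-attention: queries from model 2's stream, keys/values from model 1's stream
        attn_out, _ = self.attn(
            self.ln_q(query_states),
            self.ln_kv(key_value_states),
            self.ln_kv(key_value_states),
        )
        hidden = query_states + attn_out                 # residual around the cross-attention
        hidden = hidden + self.mlp(self.ln_mlp(hidden))  # residual around the feed-forward
        return hidden

Note that this sketch does not apply model 1's padding mask inside the cross-attention (the loop above doesn't pass one); any block that takes (query_states, key_value_states) and returns a tensor shaped like hidden_states2 will slot into the loop.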