hi @ptrblck
Could you help me with a similar issue?
I am trying to train a T5 encoder (fine-tuned on QA) along with a full T5 model (encoder + decoder, fine-tuned on summarization), using a mixed objective: sequence (summary) generation plus a binary prediction of whether the summary contains the answer. When I pass the 0/1 target class through the DataLoader, I think the model tries to embed it, and I see this error:
<ipython-input-27-c1d41cd4adff> in validation_step(self, batch, batch_idx)
     89             question_input_ids=question_input_ids,
     90             question_attention_mask=question_attention_mask,
---> 91             question_labels=question_labels.to(torch.float64)
     92         )
     93

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.DoubleTensor instead (while checking arguments for embedding)
Here, question_labels are the 0/1 target labels.
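If it helps, I can reproduce the same error in isolation; a minimal sketch (separate from my model) showing that an embedding lookup only accepts Long/Int indices, which is why the float64 cast fails:

    import torch
    import torch.nn as nn

    emb = nn.Embedding(10, 4)
    idx = torch.tensor([0, 1])      # int64 indices
    emb(idx)                        # works
    emb(idx.to(torch.float64))     # RuntimeError: Expected tensor for argument #1 'indices' ...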
Model →
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.optim import AdamW
from transformers import T5ForConditionalGeneration

class CrossAttentionSummarizer(pl.LightningModule):
    def __init__(self):
        super(CrossAttentionSummarizer, self).__init__()
        # MODEL_NAME is defined earlier in the notebook (a t5-base checkpoint,
        # given embed_dim=768 below)
        self.summarizer_model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)
        self.qa_encoder = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True)
        self.multihead_attn = nn.MultiheadAttention(embed_dim=768, num_heads=4, batch_first=True)
        self.linear1 = nn.Linear(512 * 768, 512)
        self.linear2 = nn.Linear(512, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.bce_loss = nn.BCELoss()
    def forward(self, question_input_ids, question_attention_mask, question_labels,
                input_ids, attention_mask, decoder_attention_mask, labels=None):
        summarizer_output = self.summarizer_model(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask
        )
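        # NOTE: question_labels is passed as the third positional argument
        # below. In T5ForConditionalGeneration.forward the third parameter is
        # decoder_input_ids, so I suspect my float64 labels end up in the
        # decoder's embedding lookup, which is where the traceback points.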
        qa_output = self.qa_encoder(
            question_input_ids,
            question_attention_mask,
            question_labels
        )
        decoder_output = summarizer_output[3]
        encoder_output = qa_output[2]
        multi_attn_output, multi_attn_output_weights = self.multihead_attn(
            decoder_output, encoder_output, encoder_output
        )
        lin_output = self.linear1(multi_attn_output.reshape(-1, 512 * 768))
        cls_outputs = self.linear2(lin_output)
        cls_preds = self.sigmoid(cls_outputs)
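        # cls_preds is float32 with shape (batch, 1); question_labels is
        # float64 and, I believe, shape (batch,), so dtype/shape may also
        # mismatch here for BCELoss.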
        cls_pred_loss = self.bce_loss(cls_preds, question_labels)
        return summarizer_output.loss, summarizer_output.logits, cls_pred_loss, cls_preds
    def training_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["labels"]
        labels_attention_mask = batch["labels_attention_mask"]
        question_input_ids = batch["question_input_ids"]
        question_attention_mask = batch["question_attention_mask"]
        question_labels = batch["question_labels"]
        loss, outputs, cls_pred_loss, cls_pred = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
            question_input_ids=question_input_ids,
            question_attention_mask=question_attention_mask,
            question_labels=question_labels.to(torch.float64)
        )
        self.log("train_loss", loss, prog_bar=True, logger=True)
        self.log("train_pred_loss", cls_pred_loss, prog_bar=True, logger=True)
        return loss, cls_pred_loss
    def validation_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["labels"]
        labels_attention_mask = batch["labels_attention_mask"]
        question_input_ids = batch["question_input_ids"]
        question_attention_mask = batch["question_attention_mask"]
        question_labels = batch["question_labels"]
        loss, outputs, cls_pred_loss, cls_pred = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
            question_input_ids=question_input_ids,
            question_attention_mask=question_attention_mask,
            question_labels=question_labels.to(torch.float64)
        )
        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_pred_loss", cls_pred_loss, prog_bar=True, logger=True)
        return loss, cls_pred_loss
    def test_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["labels"]
        labels_attention_mask = batch["labels_attention_mask"]
        question_input_ids = batch["question_input_ids"]
        question_attention_mask = batch["question_attention_mask"]
        question_labels = batch["question_labels"]
        loss, outputs, cls_pred_loss, cls_pred = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
            question_input_ids=question_input_ids,
            question_attention_mask=question_attention_mask,
            question_labels=question_labels.to(torch.float64)
        )
        self.log("test_loss", loss, prog_bar=True, logger=True)
        self.log("test_pred_loss", cls_pred_loss, prog_bar=True, logger=True)
        return loss, cls_pred_loss
    def configure_optimizers(self):
        return AdamW(self.parameters(), lr=0.0001)
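For reference, this is the change I was planning to try: calling only the QA model's encoder with keyword arguments (so the labels never go through the T5 call at all) and casting the labels to float32 just for the BCE loss. The .encoder attribute, the .float() cast, and the .unsqueeze(1) reshape are my assumptions, not something from the error message. Does this look right?

    # inside forward(), replacing the current qa_encoder call (my assumption):
    qa_output = self.qa_encoder.encoder(
        input_ids=question_input_ids,
        attention_mask=question_attention_mask
    )
    encoder_output = qa_output.last_hidden_state   # (batch, seq_len, 768)

    # classification loss with float32 targets shaped like cls_preds:
    cls_pred_loss = self.bce_loss(cls_preds, question_labels.float().unsqueeze(1))

    # and in the *_step methods, pass the labels through unchanged:
    # question_labels=question_labels   (drop the .to(torch.float64))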