def forward(self, x2):
    """Fuse image-model features with clinical covariates and run the head model.

    Args:
        x2: A pair ``(image, clinical_data)``. ``image`` is fed to
            ``self.modelA``. ``clinical_data`` is an iterable of tensors (one per
            clinical variable). It may be a list, a tuple or a generator, and its
            tensors are stacked along dim 1.

    Returns:
        The output of ``self.modelB`` applied to the ReLU of the image features
        concatenated with the clinical features along the last axis.
    """
    image, clinical_data = x2
    # torch.stack requires a sequence (list/tuple) of tensors and raises a
    # TypeError for a generator, so materialize whatever iterable we were given.
    clinical = torch.stack(list(clinical_data), dim=1).squeeze(0)
    # NOTE(review): squeeze(0) also removes the batch axis when the batch size
    # is 1, which would make the concatenation below fail — confirm it is needed.
    image_features = self.modelA(image)
    # Clinical values may arrive with a different dtype (e.g. float64) than the
    # image features; align dtype/device so torch.cat and modelB see one dtype.
    clinical = clinical.to(dtype=image_features.dtype, device=image_features.device)
    # Concatenate along the last (feature) axis. For 3-D tensors this is the same
    # as the original dim=2. It also works for 2-D (batch, features) outputs,
    # where dim=2 is out of range.
    fused = torch.cat((image_features, clinical), dim=-1)
    # NOTE(review): ReLU on the fused tensor also zeroes negative clinical values;
    # confirm this is intended rather than applying it to image_features only.
    return self.modelB(F.relu(fused))
def training_step(self, batch):
    """Compute the cross-entropy training loss for one batch.

    Builds the model input as ``(image, clinical_values)``:
    ``image`` is ``batch[self.hparams["weighting"]]``, and ``clinical_values``
    is a list of ``batch[key]`` for each key in ``self.keys_of_clinical``.
    The input is passed to ``self.predict``.

    Args:
        batch: Mapping containing the image entry, every clinical key, and the
            integer class targets under ``"Class"``.

    Returns:
        A dict with the differentiable ``"loss"``, the softmax probabilities
        (computed without gradient tracking) as ``"train_y_hat"``, and the
        targets as ``"train_y"``.
    """
    image = batch[self.hparams["weighting"]]
    # Use a list, not a generator expression: the model stacks these with
    # torch.stack, which only accepts a sequence of tensors.
    clinical_values = [batch[key] for key in self.keys_of_clinical]
    y = batch["Class"]
    # NOTE(review): presumably self.predict delegates to forward(); confirm.
    logits = self.predict((image, clinical_values))
    loss = F.cross_entropy(logits, y)
    with torch.no_grad():
        probs = F.softmax(logits, dim=1)
    return {"loss": loss, "train_y_hat": probs, "train_y": y}
How do I resolve this error?