Writing a dataset class that returns two outputs (images and keypoints)

I have followed your advice and created a dataset class which gives me the dataset in the form I think I need. See below:

    class KeypointDataset(Dataset):
        """Dataset pairing each image with its class label and pose keypoints.

        Rows are read from a header-less CSV file laid out as:
        column 0 = image path relative to ``root_dir``, column 1 = class
        label, columns 2.. = keypoint values, which are reshaped into rows
        of two (presumably (x, y) pairs -- confirm against the CSV).

        Args:
            csv_file: Path of the CSV file describing the samples.
            root_dir: Directory the image paths in column 0 are relative to.
            transform1: Callable applied to the loaded image.
            transform2: Callable applied to the ``(-1, 2)`` keypoint array.
        """

        def __init__(self, csv_file, root_dir, 
                     transform1=transform, transform2=transform):
            
            self.pose_frame = pd.read_csv(csv_file, header=None)
            self.images = self.pose_frame.iloc[:,0]
            self.classes = self.pose_frame.iloc[:,1]
            self.pose_kp = self.pose_frame.iloc[:,2:]
            self.root_dir = root_dir
            self.transform1 = transform1
            self.transform2 = transform2            
            
        def __len__(self):
            """Return the number of rows (samples) in the CSV file."""
            return len(self.pose_frame)
        
        def __getitem__(self, idx):
            """Return ``((image, class), (keypoints, image_path))`` for row ``idx``.

            The image and the keypoints are returned after ``transform1`` and
            ``transform2`` respectively have been applied to them.
            """
            # Positional (.iloc) lookups keep the three columns aligned with
            # each other regardless of how the frame is indexed.
            img_name = os.path.join(self.root_dir, self.images.iloc[idx])
            # Normalise Windows-style separators so the path is usable anywhere.
            img_name = img_name.replace('\\','/')
            # `io` is presumably skimage.io -- confirm against the file's imports.
            image = io.imread(img_name)
            img_as_tensor = self.transform1(image)
            
            image_class = self.classes.iloc[idx]
            
            # Series.as_matrix() was removed in pandas 1.0; to_numpy() is the
            # supported replacement.
            keypoints = self.pose_kp.iloc[idx, :].to_numpy()
            keypoints = keypoints.astype('float').reshape(-1,2)
            keypoints_as_tensor = self.transform2(keypoints)
            
            return (img_as_tensor, image_class), (keypoints_as_tensor, img_name) 

My question now is how to split the inputs into images and keypoints in the training loop. In the example you showed in Concatenate dataset, I understand the concept of adding the second input at the fc layer with x1 and x2. What I don’t know is how to tell my training loop to load these as two inputs, and how to identify x1 and x2 as the images and the keypoints respectively.
My training loop currently looks like this:

def train_model(model, criterion, optimizer, num_epochs):
        """Train ``model`` for ``num_epochs`` epochs and return the best weights.

        Reads the module-level names ``train_loader`` (iterable of
        ``(inputs, labels)`` batches), ``device`` and ``save_file_name``.

        NOTE(review): a loader built on ``KeypointDataset`` yields nested
        ``((image, class), (keypoints, name))`` batches, which the two-way
        unpacking below does not handle; the loop has to be adapted before
        it is used with that dataset.

        Args:
            model: Network called as ``model(inputs)``.
            criterion: Loss called as ``criterion(outputs, labels)``; assumed
                to return the batch *mean* (the PyTorch default reduction).
            optimizer: Optimizer updating the parameters of ``model``.
            num_epochs: Number of passes over ``train_loader``.

        Returns:
            ``(model, history)``: the model loaded with the weights of the
            epoch with the highest training accuracy, and a DataFrame with
            the columns ``train_loss`` and ``train_acc`` (one row per epoch).
        """
        best_model_wts = copy.deepcopy(model.state_dict())
        best_acc = 0.0

        since = time.time()

        history = []

        for epoch in range(num_epochs):

            running_loss = 0.0
            total_samples = 0
            total_train = 0
            correct_train = 0

            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # Predictions are only used for the metric, so keep them out
                # of the autograd graph.
                _, predicted = torch.max(outputs.detach(), 1)

                # The loss is a per-batch mean: weight it by the batch size so
                # the epoch figure is a true per-sample average even when the
                # last batch is smaller than the others.
                batch_size = labels.size(0)
                running_loss += loss.item() * batch_size
                total_samples += batch_size
                total_train += labels.nelement()  # number of label elements in batch
                correct_train += (predicted == labels).sum().item()

            # Guard against an empty loader instead of dividing by zero.
            epoch_loss = running_loss / total_samples if total_samples else 0.0
            epoch_acc = correct_train / total_train if total_train else 0.0

            print('Epoch {} Training Loss: {:.4f}  Training Accuracy {:.4f}'.format(
                    epoch, epoch_loss, epoch_acc))

            history.append([epoch_loss, epoch_acc])

            if epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()

        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
                time_elapsed // 60, time_elapsed % 60))
        # Only training accuracy is tracked here, so label it as such.
        print('Best train Acc: {:.4f}'.format(best_acc))

        #format history
        history = pd.DataFrame(
                history,
                columns=['train_loss','train_acc'])

        # NOTE(review): this stores the weights of the *last* epoch; the best
        # weights are only restored into the returned model afterwards.
        torch.save(model.state_dict(), save_file_name)
        #load best model weights
        model.load_state_dict(best_model_wts)
        return model, history

Many thanks for your help so far :grinning: