Correct batch data format for object detection like SSD

Hey,

the last days I tried to train a SSD network on customer dataset. In my case it is WIDERFACES. I always run into an issue when I set my batch size higher than 1. The issue comes from the targets.

I use the following shape:

# batch = 2
{
'labels': tensor([[1],[1]]), 
'boxes': tensor([[142.9688,  27.9485, 181.6406,  81.4668],[ 84.9609, 156.5217, 238.7695, 293.1257]])
}

if I start to train the network I run into the following issue:

**~\AppData\Local\Programs\Python\Python39\lib\site-packages\torchvision\models\detection\transform.py** in forward**(self, images, targets)** 113 **for** i **in** range**(** len**(** images**)** **)** **:** 114 image **=** images**[** i**]** **--> 115** target_index **=** targets**[** i**]** **if** targets **is** **not** **None** **else** **None** 116 target_index **=** **None** 117 **if** image**.** dim**(** **)** **!=** **3** **:** **IndexError** : list index out of range

If I go with the batchsize=1 the training works perfect.

Here is my code:


# PREPARE DATA
from torch.utils.data import Dataset
from torch import IntTensor, LongTensor, FloatTensor
from PIL import Image, ImageDraw
import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np

transform = A.Compose([
    A.Resize(width=300, height=300),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, always_apply=False, p=1.0),
    ToTensorV2(),
], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['category_ids'])) 

class wildfaces(Dataset):
    def __init__(self, root_dir, pd_df, transform=None):
        self.root_dir = root_dir
        self.df = pd_df
        self.transform = transform
        self.img_path = self.df['img_path']
        self.boundingbox = self.df['pascal_bb']

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root_dir, self.img_path.iloc[idx])
        image =  Image.open(img_path)
        image = np.array(image)
        boundingbox = self.boundingbox.iloc[idx]
        boundingbox = LongTensor(boundingbox)
        if self.transform:
            transformed = self.transform(image=image, bboxes=[boundingbox], category_ids=[1])
        target = {}
        target["labels"] = torch.ones((1), dtype=torch.int64) #IntTensor([1]) #
        target["boxes"] = FloatTensor(transformed["bboxes"][0])
        return transformed["image"], target

# LOAD DATA
from torch.utils.data import DataLoader
training_data = wildfaces("./data/face-recognition/wider-face/WIDER_train/images",df[:20], transform)
train_dataloader = DataLoader(training_data, batch_size=2, shuffle=True)


# TRAIN DATA
import torch.optim as optim
import torch
import torchvision
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = torchvision.models.detection.ssd300_vgg16(pretrained= False, progress=True, pretrained_backbone=5, trainable_backbone_layers=False, num_classes=2)
parameters = filter(lambda p: p.requires_grad, net.parameters())
optimizer = optim.RMSprop(parameters, lr=0.001)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=10)

num_epochs = 3
net.to(device)

for epoch in range(num_epochs):  # loop over the dataset multiple times
    net.train()
    total_losses = 0
    for i, data in enumerate(train_dataloader):
        images, targets = data

        images = images.float()
        images = images.to(device)
        targets["labels"] = targets["labels"].squeeze(1).to(device)
        targets["boxes"] = targets["boxes"].to(device)
        
        loss_dict = net(images,[targets])
        losses = loss_dict["bbox_regression"].sum() + loss_dict["classification"].sum()
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        total_losses = total_losses + losses.item()

    print("Loss: {} Lr: {}".format(total_losses/i, optimizer.param_groups[0]['lr']))
    scheduler.step(total_losses/20) #4631

print('Finished Training')