Code review to improve training speed

I’m currently training a retinanet model for fashion apparel with pytorch. However I’m getting very slow training times (0.5s per step on 1333x1333 inputs).
I’m training this on a queue server which makes debugging very hard (also I have a job scheduling limit).
Can anyone help me with a code review?
The main bits I need reviews with is at collate_fn and with the hist object.

import torch

from torch_collections.models.retinanet import RetinaNet
from torch_collections.losses import DetectionFocalLoss, DetectionSmoothL1Loss

def main():
    # Load dataset
    #### INSERT YOUR DATASET HERE ####

    # dataset is a torch.utils.data.Dataset object
    # dataset should output samples of dict type with keys 'image' and 'annotations'
    # Here is what a sample would look like
    # sample = {
    #     'image'       : A numpy.ndarray image of dtype uint8 in RGB, HWC format and
    #     'annotations' : A numpy.ndarray of shape (number_of_annotations, 5)
    #                     Each annotation is of the format (x1, y1, x2, y2, class_id)
    # }
    dataset = YOUR_DETECTION_DATASET

    ##################################

    # Load model with initial weights
    #### INSERT NUMBER OF CLASSES HERE ####
    retinanet = RetinaNet(num_classes=NUMBER_OF_CLASSES)
    #######################################

    # Create dataset iterator
    dataset_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        collate_fn=retinanet.collate_fn,
        num_workers=2,
    )
    # The retinanet.collate_fn uses GPU to perform a rather costly transformation
    # Some tensors here are shared, my reasoning for using 2 workers is to read images
    # from files faster (some images are high-res)
    # This could possibly be a major road block as there could be needless IO from GPU to CPU
    # and back.

    # Initialize loss functions and optimizer
    focal_loss_fn = DetectionFocalLoss()
    huber_loss_fn = DetectionSmoothL1Loss()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, retinanet.parameters()), lr=0.00001)

    for epoch in range(50):
        dataset_iterator = dataset_loader.__iter__()
        hist = {
            'loss': [],
            'regression_loss': [],
            'classification_loss': []
        }

        for step_nb in range(10000):
            # Get sample and zero optimizer
            sample = dataset_iterator.next()
            optimizer.zero_grad()

            # forward
            outputs = retinanet(sample['image'])
            regression_loss = huber_loss_fn(outputs[0], sample['regression'])
            classification_loss = focal_loss_fn(outputs[1], sample['classification'])

            # If all anchors are negative then we just skip this epoch
            if classification_loss is None:
                continue

            # If there are no positive anchors we skip regression loss
            if regression_loss is None:
                regression_loss = 0

            loss = regression_loss + classification_loss

            # backward + optimize
            loss.backward()
            optimizer.step()

            # *** My main question is if storing step data like this counts as making
            # cuda copies (which I heard is an expensive process)
            hist['loss'].append(loss)
            hist['regression_loss'].append(regression_loss)
            hist['classification_loss'].append(classification_loss)

            # Clean batch
            # *** I am not sure if this is a required step but somehow I stopped
            # getting cuda runtime error (30) after this
            del loss
            del regression_loss
            del classification_loss
            del sample['image']
            del sample['regression']
            del sample['classification']
            del sample

        torch.save(retinanet, 'snapshot/epoch_{}.pth'.format(epoch))
        print('Epoch {} - loss: {} - regression: {} - classification: {}'.format(
            epoch,
            torch.sum(torch.stack(hist['loss'])).item(),
            torch.sum(torch.stack(hist['regression_loss'])).item(),
            torch.sum(torch.stack(hist['classification_loss'])).item()
        ))

    print('Finished Training')


if __name__ == '__main__':
    # This step became required for python 3.X to share tensors for multiprocessing
    torch.multiprocessing.set_start_method("spawn", force=True)
    main()