PyTorch tensor.to(device) for a List of Dict

I am working on an image object detection application using PyTorch torchvision.models.detection.fasterrcnn_resnet50_fpn. According to the documentation, during the training phase the input to the fasterrcnn_resnet50_fpn model should be:

- list of image tensors, each of shape [C, H, W]
- list of target dicts, each with:
    - boxes (FloatTensor[N, 4]): the ground-truth boxes in [x1, y1, x2, y2] format,
                                 with x values between 0 and W and y values between 0 and H
    - labels (Int64Tensor[N]): the class label for each ground-truth box
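To make the expected format concrete, here is a minimal sketch that builds a dummy mini-batch in that shape. The helper name make_dummy_batch and the image sizes are hypothetical; real images and boxes would come from your dataset.

```python
import torch

def make_dummy_batch(batch_size=2, num_boxes=3, num_classes=5):
    """Hypothetical helper: builds inputs in the documented training format."""
    image_list, target_list = [], []
    for i in range(batch_size):
        # Images in a batch may differ in size, so they are kept in a list, not stacked
        h, w = 200 + 10 * i, 300
        image_list.append(torch.rand(3, h, w))  # float image in [0, 1], shape [C, H, W]

        # Random boxes with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H
        x1 = torch.rand(num_boxes) * (w / 2)
        y1 = torch.rand(num_boxes) * (h / 2)
        x2 = x1 + 1 + torch.rand(num_boxes) * (w / 2 - 1)
        y2 = y1 + 1 + torch.rand(num_boxes) * (h / 2 - 1)
        boxes = torch.stack([x1, y1, x2, y2], dim=1)  # FloatTensor[N, 4]

        # int64 class labels; torchvision detection models treat 0 as background
        labels = torch.randint(1, num_classes, (num_boxes,), dtype=torch.int64)
        target_list.append({"boxes": boxes, "labels": labels})
    return image_list, target_list

image_list, target_list = make_dummy_batch()
print(len(image_list), image_list[0].shape, target_list[0]["boxes"].shape)
```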

I plan to train in mini-batches, so each list will contain more than one image tensor and one target dict.

With the two lists prepared, I can train with my fasterrcnn_resnet50_fpn model:

model.train()
for image_list, target_list in dataset_loader:
    """
    image_list: list of image tensors
    target_list: list of dicts {boxes, labels}
    """
    # some steps before ...

    # Feed inputs to model in training phase
    outputs = model(image_list, target_list)

    # more steps after ...
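For a loop like the one above, the loader has to yield lists rather than stacked tensors, because detection images often differ in size and the default collate function would try to stack them. A minimal sketch, assuming a hypothetical ToyDetectionDataset and a hand-written collate_fn (torchvision's detection reference scripts use a similar zip-based collate, but this is not an official API):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyDetectionDataset(Dataset):
    """Hypothetical dataset returning (image, target) pairs with varying image sizes."""
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        image = torch.rand(3, 100 + idx, 120)  # images differ in height
        target = {
            "boxes": torch.tensor([[10.0, 10.0, 50.0, 60.0]]),
            "labels": torch.tensor([1], dtype=torch.int64),
        }
        return image, target

def collate_fn(batch):
    # batch is a list of (image, target) pairs; regroup it into
    # (list of images, list of targets) instead of stacking,
    # since images of different sizes cannot be stacked
    images, targets = zip(*batch)
    return list(images), list(targets)

dataset_loader = DataLoader(ToyDetectionDataset(), batch_size=2, shuffle=False,
                            collate_fn=collate_fn)
for image_list, target_list in dataset_loader:
    print(len(image_list), [img.shape for img in image_list])
```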

Everything works well, but there is one issue: by default, the tensors are stored on the CPU, but I would like to train on the GPU.

Naively, I can iterate through the lists and apply .to(device) to each tensor:

image_list = [t.to(device) for t in image_list]
# Move both boxes and labels; leaving labels on the CPU causes a device mismatch
target_list = [{'boxes': d['boxes'].to(device), 'labels': d['labels'].to(device)} for d in target_list]
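A slightly more general version of this loop moves every tensor in each target dict (so a key such as labels is never left on the CPU) and passes non_blocking=True when the target is a GPU. The asynchronous copy only helps if the source tensors are in pinned memory, for example when the DataLoader is created with pin_memory=True. This is a sketch; batch_to_device is a hypothetical name.

```python
import torch

def batch_to_device(image_list, target_list, device):
    # Async copies only apply to CUDA targets and need pinned source memory to overlap
    non_blocking = device.type == "cuda"
    image_list = [img.to(device, non_blocking=non_blocking) for img in image_list]
    # Move every tensor-valued entry of each target dict; leave other values (e.g. ints) alone
    target_list = [
        {k: v.to(device, non_blocking=non_blocking) if torch.is_tensor(v) else v
         for k, v in t.items()}
        for t in target_list
    ]
    return image_list, target_list

# Usage with dummy data; falls back to the CPU when no GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
images = [torch.rand(3, 8, 8), torch.rand(3, 6, 10)]
targets = [{"boxes": torch.zeros(1, 4), "labels": torch.ones(1, dtype=torch.int64), "image_id": 7}]
images, targets = batch_to_device(images, targets, device)
```

Note that .to(device) returns the tensor itself, without copying, when it is already on the target device, so calling it every iteration costs nothing extra for tensors that are already there.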

However, I doubt this is the most memory- or computation-efficient approach, since it simply loops through the lists one tensor at a time.

Therefore, my question is: are there better ways to apply tensor.to(device) to a list of tensors or a list of dicts, with better memory/computation efficiency and, ideally, better readability?

Hi,

I’m afraid there is none.
You can use something like this to make it cleaner, but that’s pretty much it:

import torch

def move_to(obj, device):
    # Tensors are moved directly
    if torch.is_tensor(obj):
        return obj.to(device)
    # Dicts: recurse into the values, keeping the keys
    elif isinstance(obj, dict):
        res = {}
        for k, v in obj.items():
            res[k] = move_to(v, device)
        return res
    # Lists: recurse into each element
    elif isinstance(obj, list):
        res = []
        for v in obj:
            res.append(move_to(v, device))
        return res
    else:
        raise TypeError("Invalid type for move_to")

It would be nice if PyTorch provided such a function. Your solution looks good. Thank you so much!
