Best practices for handling batch caching with the GPU

Hi, I am learning the best ways to manage batches, along with other best practices for model training and inference, and I have some questions:

  • If I move a batch to the GPU, should I move it back to the CPU after the training step? If not, why?

    batch, label = batch.cuda(), label.cuda()
    # ... training pass ...
    batch, label = batch.cpu(), label.cpu()

    If I cache my data in my Dataset class, how can I make sure the same batches are reused on the GPU, to avoid transferring them to and from the CPU multiple times?

  • Is it best practice to compute all performance metrics on detached tensors, to avoid keeping extra computation graph nodes in memory?

Thanks in advance for any possible help.

  1. Usually you would neither move the batch back to the CPU nor cache it, since you are often applying random transformations on the fly. However, if your use case is different, you could of course store the data in a manual cache and try to reuse it. In this case, moving it back to the CPU would save GPU memory and would most likely be needed (unless you are working with a tiny dataset and model and your GPU is able to store both).

  2. Yes, I would recommend calling detach() on these tensors, just in case third-party libraries store them in some way, which could increase GPU memory usage since the entire computation graph would also be kept alive. Since metrics are often calculated on the validation set, you could also wrap the entire validation loop in a with torch.no_grad() block.
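A minimal sketch of both suggestions, using a tiny linear model and random data as hypothetical stand-ins for a real model and dataset. The pitfall it shows is accumulating the loss tensor itself across iterations, which would chain every iteration's graph into the accumulator; `detach()` keeps only the value, and inside `torch.no_grad()` no graph is recorded in the first place.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
model = nn.Linear(8, 3)  # hypothetical stand-in for a real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [(torch.randn(16, 8), torch.randint(0, 3, (16,))) for _ in range(4)]

running_loss = torch.zeros(())
for x, y in batches:
    optimizer.zero_grad()
    logits = model(x)
    loss = F.cross_entropy(logits, y)
    loss.backward()
    optimizer.step()
    # `running_loss += loss` would attach each iteration's graph to
    # running_loss and keep it alive; detach() keeps only the value.
    running_loss += loss.detach()


def validate(model, loader):
    model.eval()
    correct, total = 0, 0
    # No graph is recorded inside no_grad(), so nothing needs detaching here.
    with torch.no_grad():
        for x, y in loader:
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.numel()
    model.train()
    return correct / total


val_acc = validate(model, batches)
print(f"mean train loss: {running_loss.item() / len(batches):.3f}, val acc: {val_acc:.3f}")
```

Using `loss.item()` instead of `loss.detach()` also works; it returns a Python float, but when the loss lives on the GPU each call forces a synchronization with the device.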

Thanks for your answer! I am working with video data, which is quite heavy. As you can see in my custom VideoDataset below, I am using a caching dictionary in my __getitem__ method.

When the DataLoader requests an item, the cached item is returned. I wonder whether repeatedly transferring these batches to the GPU would result in an out-of-memory (OOM) error.

Isn't PyTorch already flushing the unused batches from the GPU? Or do you think it is a better idea to transfer the items to the GPU within the `Dataset` class, even though I read that it is not a great idea?
Thanks for your hints!

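import numpy as np
import torch
from torch.utils.data import Dataset
from decord import VideoReader  # assumption: decord's reader (supports len() and frame indexing)

class VideoDataset(Dataset):

  def __init__(self, data_path, classes, caching_dict=None, transforms=None, n_frames=None):
    super(VideoDataset, self).__init__()

    self.data_path = data_path  # sequence of (video_path, label) pairs, as unpacked in __getitem__
    self.classes = classes      # maps each label to its class index

    self.transforms = transforms
    self.n_frames = n_frames

    self.caching_dict = caching_dict  # plain dict, filled lazily in __getitem__

  def read_video(self, path):

    vr = VideoReader(path)

    if len(vr) >= self.n_frames:
      # sample n_frames evenly spaced frame indices
      idxs = np.linspace(0, len(vr), self.n_frames, endpoint=False, dtype=int)
      frames = []
      for i in idxs:
        # assumption: each frame is converted to an (H, W, C) uint8 tensor
        frame = torch.from_numpy(vr[int(i)].asnumpy())
        if self.transforms:
          frame = self.transforms(frame)
        frames.append(frame)
      return torch.stack(frames, dim=0)

    print(f'n_frames={self.n_frames} exceeds the {len(vr)} frames of video {path}. Skipping video')
    return None  # note: the default collate_fn cannot batch None

  def __getitem__(self, index):

    v_path, label = self.data_path[index]
    if self.caching_dict is not None:
      if index not in self.caching_dict:
        self.caching_dict[index] = self.read_video(v_path), self.classes[label]
      return self.caching_dict[index]
    return self.read_video(v_path), self.classes[label]

  def __len__(self): return len(self.data_path)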
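To illustrate the first answer with this setup, here is a minimal sketch that keeps the cache on the CPU and copies each batch to the device only inside the training loop, so the GPU holds one batch at a time rather than the whole cache. `FakeVideoDataset` is hypothetical: it replaces `read_video` with random tensors so the example runs without video files or `VideoReader`. `num_workers=0` is also a deliberate assumption: with worker processes, each worker would fill its own copy of the dict, and the main process's cache would stay empty.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class FakeVideoDataset(Dataset):
    """Hypothetical stand-in for VideoDataset: random 'videos' instead of files."""

    def __init__(self, n_videos, n_frames=8, caching_dict=None):
        self.n_videos = n_videos
        self.n_frames = n_frames
        self.caching_dict = caching_dict
        self.reads = 0  # counts how often the expensive "decode" ran

    def read_video(self, index):
        self.reads += 1
        g = torch.Generator().manual_seed(index)
        # (frames, channels, height, width), created and kept on the CPU
        return torch.rand(self.n_frames, 3, 32, 32, generator=g)

    def __getitem__(self, index):
        label = index % 2
        if self.caching_dict is None:
            return self.read_video(index), label
        if index not in self.caching_dict:
            self.caching_dict[index] = (self.read_video(index), label)
        return self.caching_dict[index]

    def __len__(self):
        return self.n_videos


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cache = {}
dataset = FakeVideoDataset(n_videos=6, caching_dict=cache)
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=0)

for epoch in range(3):
    for videos, labels in loader:
        # The cache holds CPU tensors; this creates a fresh device copy per batch.
        videos, labels = videos.to(device), labels.to(device)
        # ... forward/backward pass ...
        # No .cpu() call is needed: rebinding `videos` on the next iteration
        # drops the last reference to this device copy.

print(f"decoded {dataset.reads} times for {len(dataset)} videos over 3 epochs")
```

Only the first epoch decodes; later epochs are served from the CPU-side dict, and the per-batch host-to-device copy is the only repeated transfer.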

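I don't know what "flushing the unused batches in GPU" means, but in case you are wondering when tensors are deleted: once all references to a tensor are deleted, it will be released (for CUDA tensors, the memory goes back to PyTorch's caching allocator, so nvidia-smi may still report it as used).
If you want to transfer the data to the GPU inside the Dataset, you would have to make sure multiprocessing works in this setup (CUDA cannot be re-initialized in a forked DataLoader worker once the parent process has initialized it, so you would need e.g. the "spawn" start method), and I would recommend profiling the actual performance once it is working, to check whether you see any speedup at all.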
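One way to do that check is to time a full pass over the loader, including the device transfer, for each variant of the Dataset. The sketch below is hypothetical: `time_loader` is not a PyTorch API, and the random `TensorDataset` stands in for the video data. It calls `torch.cuda.synchronize()` before stopping the clock because copies issued with `non_blocking=True` from pinned memory can still be in flight when the loop ends.

```python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def time_loader(loader, device, n_epochs=1):
    """Hypothetical helper: wall-clock seconds to iterate `loader` and move each batch to `device`."""
    start = time.perf_counter()
    n_batches = 0
    for _ in range(n_epochs):
        for x, y in loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            n_batches += 1
    if device.type == "cuda":
        # Asynchronous copies may still be running; wait before stopping the clock.
        torch.cuda.synchronize(device)
    return time.perf_counter() - start, n_batches


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = TensorDataset(torch.rand(64, 8, 3, 32, 32), torch.randint(0, 2, (64,)))
# pin_memory only makes sense when there is a GPU to copy to.
loader = DataLoader(data, batch_size=8, pin_memory=device.type == "cuda")

elapsed, n_batches = time_loader(loader, device, n_epochs=2)
print(f"{n_batches} batches in {elapsed:.3f}s")
```

Running the same helper on a loader whose Dataset already returns GPU tensors would give a direct comparison for this particular workload.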

Yes, I mean when the batch is released from the GPU. When the reference is reassigned during the for loop, I suppose the previous batch is released, right?
Thanks for your help.

Yes, that is right, as long as you don't reference the previous tensor in any other way (e.g. by appending it to a list, or by keeping an output whose computation graph still holds it).
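This can be checked with a weak reference, which does not keep its target alive. The sketch uses CPU tensors so it runs anywhere; `batch_stream` and `count_released` are hypothetical helpers, with `batch_stream` standing in for a DataLoader that yields a fresh tensor each step. For CUDA tensors the same reference rule applies, with the freed memory returned to PyTorch's caching allocator.

```python
import weakref

import torch


def batch_stream(n):
    # Hypothetical stand-in for a DataLoader: yields a fresh tensor each step.
    for _ in range(n):
        yield torch.rand(4, 3)


def count_released(keep_history):
    """Count how many previous batches were freed once `batch` was rebound."""
    history = []
    prev_ref = None
    released = 0
    for batch in batch_stream(5):
        # `batch` has just been rebound; check whether the previous tensor survived.
        if prev_ref is not None and prev_ref() is None:
            released += 1
        prev_ref = weakref.ref(batch)
        if keep_history:
            history.append(batch)  # an extra reference keeps every batch alive
    return released


print(count_released(keep_history=False))
print(count_released(keep_history=True))
```

With `keep_history=True`, nothing is released until the list itself goes away, which is the kind of lingering reference that would eventually cause an OOM on the GPU.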
