How to know the memory allocated for a tensor on GPU?

How can we find out the total memory allocated for a tensor on the GPU? All the statements below return 72, so it looks like I am missing something.

print(sys.getsizeof(torch.FloatTensor([0.5]).cuda()))
72
print(sys.getsizeof(torch.FloatTensor([0.5])))
72
print(sys.getsizeof(torch.FloatTensor([0.5, 0.7])))
72
print(sys.getsizeof(torch.FloatTensor([0.5, 0.7]).cuda()))
72

Or is it safe to assume that if a float tensor is on the GPU, then the total memory consumed by the tensor is 4 bytes * length_of_tensor?

It would also be useful to know how to calculate the memory consumed by any object on the GPU, e.g. torchtext.data.dataset.Dataset.

4 Likes

Hi,

sys.getsizeof() will return the size of the Python object. It will be the same for all tensors, since every tensor is just a Python object wrapping the underlying data.
For each tensor, you have a method element_size() that gives you the size of one element in bytes, and a method nelement() that returns the number of elements.
So the size of a tensor a in memory (CPU memory for a CPU tensor and GPU memory for a GPU tensor) is a.element_size() * a.nelement().
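For example, computing this for a small tensor (a minimal sketch, assuming a CUDA device is available):

import torch

a = torch.FloatTensor([0.5, 0.7]).cuda()
# 4 bytes per float32 element * 2 elements = 8 bytes of tensor data
print(a.element_size() * a.nelement())  # 8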

All objects are stored in CPU memory. The only things that can use GPU memory (among all PyTorch objects) are tensors. So the GPU memory used by any object is the memory used by the GPU tensors it contains.
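For an nn.Module, for instance, you could sum over its parameters and buffers (a rough sketch; module_gpu_bytes is a hypothetical helper and only counts tensors reachable via parameters() and buffers()):

import torch
import torch.nn as nn

def module_gpu_bytes(module):
    # Sum the data size of every parameter/buffer that lives on a CUDA device.
    total = 0
    for t in list(module.parameters()) + list(module.buffers()):
        if t.is_cuda:
            total += t.element_size() * t.nelement()
    return total

model = nn.Linear(1024, 1024).cuda()
print(module_gpu_bytes(model))  # weight + bias: (1024*1024 + 1024) * 4 bytes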

61 Likes

Thank you for the detailed reply @albanD! Regarding the other objects on the GPU, are you saying that the functions that track the graph for each tensor etc. are stored on the CPU?

Yes, they don’t actually need to be on GPU (and can’t). Only Tensors need to be there to perform operations on them.

2 Likes

I have a follow-up question: how do you determine how much memory the computational graph of
a tensor occupies? This is sometimes important to know for memory-bottlenecked networks, since one can switch to a less memory-hungry model with comparable performance, if such a model exists (for example a GRU instead of an LSTM).

2 Likes

Does nelement() take into account duplicated memory in a view? If I understand correctly, the memory requirements for img and img_3_dims should be the same!

I.e. with this code:

img.shape, type(img)
(torch.Size([256, 256]), <class 'torch.Tensor'>)
img.element_size(), img.nelement()
(4, 65536)
img_3_dims = img.view(img.shape + (1,)).expand(-1,-1,3)
img_3_dims.element_size(), img_3_dims.nelement()
(4, 196608)
1 Like

Hi,

The formula above only works for contiguous Tensors.
If you start to play with strides (which is what expand does) or slicing, it becomes very complex to know the memory usage. But such Tensors should be fairly rare in regular applications.

Thanks for the clarification @albanD

I used the line of code below to convert a gray-scale image to a 3-channel image so I can use my data with a pretrained model (for RGB input), without using triple the memory. I was hoping to confirm I was saving memory using view() and expand(), but I can’t find a way to explicitly measure this? Otherwise, I’ll probably just assume it’s doing what I think :slight_smile:

img_3_dims = img.view(img.shape + (1,)).expand(-1,-1,3)
1 Like

Neither expand nor view ever allocates memory. So if you only use these, you can be sure that you don't use any extra memory.
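One way to check this explicitly (a small sketch; untyped_storage() is available on recent PyTorch versions, older versions expose storage() instead):

import torch

img = torch.zeros(256, 256)
img_3_dims = img.view(img.shape + (1,)).expand(-1, -1, 3)

# The naive formula over-counts for the expanded (non-contiguous) view:
print(img_3_dims.element_size() * img_3_dims.nelement())  # 786432

# Both tensors share the same 256 * 256 * 4 = 262144-byte storage:
print(img.untyped_storage().nbytes())  # 262144
print(img_3_dims.untyped_storage().data_ptr() == img.untyped_storage().data_ptr())  # True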

2 Likes

Hi all,
so using the method given by @albanD we can, for example, know the size of the input and output tensors of the net. However, sometimes other tensors are created inside the net. In order to track the whole GPU memory usage, do we need to explicitly consider the sizes of all these tensors, or is it possible to query memory usage in a more general way?

Hi,

You can see this section of the docs, which presents all the methods we provide to inspect memory usage on the GPU.
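For instance, a short sketch of some of the torch.cuda functions from that section:

import torch

device = torch.device('cuda')
x = torch.randn(1024, 1024, device=device)

# Memory currently occupied by live tensors, and the peak so far:
print(torch.cuda.memory_allocated(device))
print(torch.cuda.max_memory_allocated(device))

# Memory reserved by the caching allocator (usually larger than what is allocated):
print(torch.cuda.memory_reserved(device))

# Human-readable report of the allocator state:
print(torch.cuda.memory_summary(device))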

5 Likes

I suppose you could force the tensors to be on your GPU (CUDA) and then inspect the memory usage as mentioned in a later answer by @albanD in this section.

Worked for me! (I could confirm that although Tensor.repeat() allocates new memory, Tensor.expand() does not.)

You can use the PyTorch profiler, which will provide you with detailed information regarding this.
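For example, a minimal sketch with torch.profiler; profile_memory=True records per-operator tensor memory usage:

import torch
from torch.profiler import profile, ProfilerActivity

x = torch.randn(1024, 1024, device='cuda')
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             profile_memory=True) as prof:
    y = x @ x
print(prof.key_averages().table(sort_by='self_cuda_memory_usage', row_limit=5))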

Why not nelement()?

Are you referring to tensor.size().numel()?

If so, I imagine it refers to the size of the view of the underlying data of a tensor. I am still looking for a way to get the exact allocated memory of the underlying data of a tensor.

BUT, you can compute it yourself if you allocated the tensor yourself, or if you know the size of the underlying data. Pretty much multiply the number of elements originally allocated by the size of each element, and round up to a multiple of 512: in my tests, PyTorch's caching allocator hands out blocks of 512 bytes, and the reserved memory grows in 2 MB chunks.

For instance,

import torch
import math
from scripts.utils.memory import print_memory_info

def ceiling(x, factor=512):
  return math.ceil(x/factor)*factor

device = torch.device('cuda')

print_memory_info()

print(f'diff in memory_allocated will be => ceiling(64*4, 512) = {ceiling(64*4)}')
x1 = torch.randint(16, size=(64, ), dtype=torch.float32, device=device)
print_memory_info()

print(f'diff in memory_allocated will be => ceiling(128*4, 512) = {ceiling(128*4)}')
x2 = torch.randint(16, size=(128, ), dtype=torch.float32, device=device)
print_memory_info()

print(f'diff in memory_allocated will be => ceiling(200*4, 512) = {ceiling(200*4)}')
x3 = torch.randint(16, size=(200, ), dtype=torch.float32, device=device)
print_memory_info()

print(f'diff in memory_allocated will be => ceiling(256*4, 512) = {ceiling(256*4)}')
x4 = torch.randint(16, size=(256, ), dtype=torch.float32, device=device)
print_memory_info()

memory.py

import torch

_cache_ = {
  'memory_allocated': 0,
  'max_memory_allocated': 0,
  'memory_reserved': 0,
  'max_memory_reserved': 0,
}

def _get_memory_info(info_name, unit):
  # Read the requested statistic from torch.cuda and report it together with
  # the change since the previous call (tracked in _cache_).
  if info_name == 'memory_allocated':
    current_value = torch.cuda.memory.memory_allocated()
  elif info_name == 'max_memory_allocated':
    current_value = torch.cuda.memory.max_memory_allocated()
  elif info_name == 'memory_reserved':
    current_value = torch.cuda.memory.memory_reserved()
  elif info_name == 'max_memory_reserved':
    current_value = torch.cuda.memory.max_memory_reserved()
  else:
    raise ValueError()

  divisor = 1
  if unit.lower() == 'kb':
    divisor = 1024
  elif unit.lower() == 'mb':
    divisor = 1024*1024
  elif unit.lower() == 'gb':
    divisor = 1024*1024*1024
  else:
    raise ValueError()

  diff_value = current_value - _cache_[info_name]
  _cache_[info_name] = current_value

  return f"{info_name}: \t {current_value} ({current_value/divisor:.3f} {unit.upper()})" \
         f"\t diff_{info_name}: {diff_value} ({diff_value/divisor:.3f} {unit.upper()})"

def print_memory_info(unit='kb'):

  print(_get_memory_info('memory_allocated', unit))
  print(_get_memory_info('max_memory_allocated', unit))
  print(_get_memory_info('memory_reserved', unit))
  print(_get_memory_info('max_memory_reserved', unit))
  print('')

Output

memory_allocated: 0 (0.000 KB) diff_memory_allocated: 0 (0.000 KB)
max_memory_allocated: 0 (0.000 KB) diff_max_memory_allocated: 0 (0.000 KB)
memory_reserved: 0 (0.000 KB) diff_memory_reserved: 0 (0.000 KB)
max_memory_reserved: 0 (0.000 KB) diff_max_memory_reserved: 0 (0.000 KB)

diff in memory_allocated will be => ceiling(64*4, 512) = 512
memory_allocated: 512 (0.500 KB) diff_memory_allocated: 512 (0.500 KB)
max_memory_allocated: 512 (0.500 KB) diff_max_memory_allocated: 512 (0.500 KB)
memory_reserved: 2097152 (2048.000 KB) diff_memory_reserved: 2097152 (2048.000 KB)
max_memory_reserved: 2097152 (2048.000 KB) diff_max_memory_reserved: 2097152 (2048.000 KB)

diff in memory_allocated will be => ceiling(128*4, 512) = 512
memory_allocated: 1024 (1.000 KB) diff_memory_allocated: 512 (0.500 KB)
max_memory_allocated: 1024 (1.000 KB) diff_max_memory_allocated: 512 (0.500 KB)
memory_reserved: 2097152 (2048.000 KB) diff_memory_reserved: 0 (0.000 KB)
max_memory_reserved: 2097152 (2048.000 KB) diff_max_memory_reserved: 0 (0.000 KB)

diff in memory_allocated will be => ceiling(200*4, 512) = 1024
memory_allocated: 2048 (2.000 KB) diff_memory_allocated: 1024 (1.000 KB)
max_memory_allocated: 2048 (2.000 KB) diff_max_memory_allocated: 1024 (1.000 KB)
memory_reserved: 2097152 (2048.000 KB) diff_memory_reserved: 0 (0.000 KB)
max_memory_reserved: 2097152 (2048.000 KB) diff_max_memory_reserved: 0 (0.000 KB)

diff in memory_allocated will be => ceiling(256*4, 512) = 1024
memory_allocated: 3072 (3.000 KB) diff_memory_allocated: 1024 (1.000 KB)
max_memory_allocated: 3072 (3.000 KB) diff_max_memory_allocated: 1024 (1.000 KB)
memory_reserved: 2097152 (2048.000 KB) diff_memory_reserved: 0 (0.000 KB)
max_memory_reserved: 2097152 (2048.000 KB) diff_max_memory_reserved: 0 (0.000 KB)

3 Likes

Another option from this StackOverflow answer is to measure the size of the underlying .storage()

import sys
import torch

t = torch.FloatTensor([0.5, 0.7])
print(sys.getsizeof(t.storage()))  # 64
t = t.cuda()
print(sys.getsizeof(t.storage()))  # 64
2 Likes

import torch

device = 'cuda:0'

# before
torch._C._cuda_clearCublasWorkspaces()
memory_before = torch.cuda.memory_allocated(device)

# your tensor
data = torch.randn((10000,100),device=device)

# after
memory_after = torch.cuda.memory_allocated(device)
latent_size = memory_after - memory_before

latent_size
# 4000256

Code taken from https://github.com/pytorch/pytorch/blob/ee28b865ee9c87cce4db0011987baf8d125cc857/torch/distributed/pipeline/sync/_balance/profile.py#L100-L99