Different memory usage between JupyterLab and script

I’m trying to train a BERT model using ‘sentence-transformers/distiluse-base-multilingual-cased-v1’ as the base. While inspecting nvidia-smi I’ve noticed that loading the model with ‘AutoModel.from_pretrained()’ takes up 400MB more GPU memory when done in a script than in JupyterLab. The larger issue, however, is that calculating text embeddings in eval() mode with torch.no_grad() in my script leaves an unexpected extra memory footprint.
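
For anyone reproducing this, a minimal sketch (not my exact code) of how the two numbers can be compared right after loading the model; PyTorch’s allocator counters only account for tensors, while nvidia-smi additionally includes the CUDA context:

import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    'sentence-transformers/distiluse-base-multilingual-cased-v1'
).to('cuda')

# Memory managed by PyTorch's caching allocator (weights, cached blocks, ...)
print(f"allocated: {torch.cuda.memory_allocated() / 1024**2:.1f} MB")
print(f"reserved:  {torch.cuda.memory_reserved() / 1024**2:.1f} MB")
# nvidia-smi additionally shows the CUDA context (loaded kernels, cuBLAS workspaces, ...),
# which these counters do not include.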

The following code, when executed as a script (this doesn’t happen when executed in JupyterLab), allocates an extra 300MB of GPU memory on top of the model:

def calculate_embeddings(self, dataset):
    # Assumes module-level imports: gc, torch and tqdm (from tqdm import tqdm).
    self.model.eval()
    batch_size = self.batch_size
    tokens = self.tokenizer([article['text'] for article in dataset],
                            padding='max_length', truncation=True, return_tensors='pt')
    input_ids = tokens['input_ids']
    attention_mask = tokens['attention_mask']

    embeddings = torch.empty(0, 768)  # accumulated on the CPU
    with torch.no_grad():
        for i in tqdm(range(0, len(input_ids), batch_size), desc="Calculating embeddings"):
            batch_ids = input_ids[i:i + batch_size].to(self.model.device)
            batch_mask = attention_mask[i:i + batch_size].to(self.model.device)

            # Pool the first-token hidden state so each batch yields a [batch, 768] tensor
            # (the pooling choice here is illustrative).
            batch_embeddings = self.model(input_ids=batch_ids,
                                          attention_mask=batch_mask).last_hidden_state[:, 0]
            embeddings = torch.cat([embeddings, batch_embeddings.detach().cpu()], 0)

            del batch_ids
            del batch_mask
            torch.cuda.empty_cache()
    del tokens
    del input_ids
    del attention_mask
    gc.collect()
    torch.cuda.empty_cache()
    return embeddings

I’ve tried to analyse the GPU usage with this function:

def pretty_size(size):
    """Pretty prints a torch.Size object"""
    assert isinstance(size, torch.Size)
    return " × ".join(map(str, size))

def dump_tensors(gpu_only=True):
    """Prints a list of the Tensors being tracked by the garbage collector."""
    import gc
    total_size = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                if not gpu_only or obj.is_cuda:
                    print("%s:%s%s %s" % (type(obj).__name__,
                                          " GPU" if obj.is_cuda else "",
                                          " pinned" if obj.is_pinned() else "",
                                          pretty_size(obj.size())))
                    total_size += obj.numel()
            elif hasattr(obj, "data") and torch.is_tensor(obj.data):
                if not gpu_only or obj.is_cuda:
                    print("%s → %s:%s%s%s %s" % (type(obj).__name__,
                                                 type(obj.data).__name__,
                                                 " GPU" if obj.is_cuda else "",
                                                 " pinned" if obj.data.is_pinned() else "",
                                                 " grad" if obj.requires_grad else "",
                                                 pretty_size(obj.data.size())))
                    total_size += obj.data.numel()
        except Exception:
            pass
    print("Total number of tensor elements:", total_size)

But the only difference I see between the two environments is the order in which the GPU parameters are printed.
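
Since dump_tensors only sees tensors tracked by Python’s garbage collector, a complementary check is the caching allocator’s own report; memory held by the CUDA context itself will not show up in either view:

print(torch.cuda.memory_summary(device=0, abbreviated=True))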

Edit: the extra 300MB is allocated during the first iteration of the ‘for’ loop and stays constant afterwards.
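
A rough, illustrative probe to narrow this down (batch_ids and batch_mask being the first prepared batch, already moved to the GPU); the allocator delta only covers tensors, so a jump visible in nvidia-smi but not here points to the CUDA context rather than to an allocation:

with torch.no_grad():
    before = torch.cuda.memory_allocated()
    out = model(input_ids=batch_ids, attention_mask=batch_mask)
    after = torch.cuda.memory_allocated()

# Tensor memory shows up in these counters; kernels loaded into the CUDA
# context during the first launch only show up in nvidia-smi.
print(f"allocator delta: {(after - before) / 1024**2:.1f} MB")
print(f"peak allocated:  {torch.cuda.max_memory_allocated() / 1024**2:.1f} MB")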

Could you check whether the PyTorch version is the same in both environments via print(torch.__version__), please?

Both are ‘1.13.1+cu117’

@ptrblck any idea what may cause this? The same thing happens when I run the remote debugger in PyCharm: the model weighs 400MB less and there’s no extra footprint from the calculation.

No, I don’t know what might be causing it, as I haven’t seen this behavior before (but I also usually don’t use Jupyter). One debugging step I would try is to disable lazy module loading in the CUDA driver, which is enabled by default in PyTorch when using CUDA >= 11.7.
You can disable it via export CUDA_MODULE_LOADING=EAGER, which will load all kernels into the CUDA context again (as was the previous behavior). If this yields the same (higher) memory usage, it would mean that your Jupyter environment somehow interacts with env variables.
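
For the script case the variable has to be set before the CUDA context is created, either in the shell that launches the script or at the very top of the script (sketch; train.py is a placeholder name):

# In the shell, before launching:
#   CUDA_MODULE_LOADING=EAGER python train.py

# Or at the very top of the script, before any CUDA work:
import os
os.environ['CUDA_MODULE_LOADING'] = 'EAGER'

import torch
torch.cuda.init()  # the setting takes effect when the CUDA context is created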


Apparently it was exactly that. Thank you!

Thanks for checking! It’s interesting to see that Jupyter somehow doesn’t respect these env variables (or maybe replaces them), which also explains why users often struggle to set CUDA_LAUNCH_BLOCKING=1 in their Jupyter environment.
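
In a notebook the usual workaround is to set such variables from Python in the very first cell, before torch is imported and before any CUDA context exists (a sketch; whether it sticks still depends on how the kernel was launched):

# First notebook cell, before importing torch:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['CUDA_MODULE_LOADING'] = 'EAGER'
# or via the IPython magic:  %env CUDA_LAUNCH_BLOCKING=1

import torch
print(torch.__version__)  # the CUDA context itself is only created on first CUDA use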
