@smth I think your method of finding all tensors via Python's garbage collector does not account for every tensor. One corner case is backpropagation: a tensor may be saved in a context for the backward pass and transformed (possibly compressed in some way), so it no longer appears as a tensor object. I wrote a method that also accounts for the saved_tensors stored in the context for the backward pass. Could you please check whether it extracts all the saved tensors correctly?
def get_tensors(only_cuda=False, omit_objs=None):
    """Return the PyTorch tensors reachable through Python's garbage collector.

    Every object returned by ``gc.get_objects()`` is inspected:

    * a tensor is collected directly;
    * a non-tensor object whose ``.data`` attribute is a tensor contributes
      that tensor;
    * an object exposing ``saved_tensors`` (for example the ``ctx`` of a custom
      ``torch.autograd.Function``) contributes each tensor it yields, so that
      tensors stashed for the backward pass are counted too.

    NOTE(review): built-in autograd nodes appear to expose their saved values
    through ``_saved_*`` attributes rather than ``saved_tensors``. Such tensors
    are not covered here unless they are reachable some other way; confirm
    against the PyTorch autograd docs.

    Tensors are de-duplicated by Python object identity (``id``). Accessing
    ``saved_tensors`` may build a new Python object for a saved tensor, so such
    a tensor may be reported next to another live tensor sharing its storage.

    :param only_cuda: if true, return only tensors that live on a CUDA device.
    :param omit_objs: iterable of objects (or their ``id()`` values) whose
        ``saved_tensors`` must not be inspected. ``None`` means omit nothing.
    :return: list of the detected tensors.

    >>> import torch
    >>> t1 = torch.tensor([1])
    >>> a3 = torch.tensor([[1, 2], [3, 4]])
    >>> found = get_tensors()
    >>> any(t is t1 for t in found) and any(t is a3 for t in found)
    True
    >>> all(t.is_cuda for t in get_tensors(only_cuda=True))
    True
    """
    # Normalise the omit list to a set of ids. Ints are taken as ids already
    # (earlier callers could only get a match by passing ids); any other object
    # is converted with id(). Ints have no saved_tensors, so this is unambiguous.
    omit_ids = set()
    for omitted in omit_objs or ():
        omit_ids.add(omitted if isinstance(omitted, int) else id(omitted))

    # Keyed by id() so that the same tensor object is only reported once.
    tensors = {}

    def add_tensor(obj):
        """Record ``obj`` (or its ``.data``) if it is a tensor passing the device filter."""
        if torch.is_tensor(obj):
            tensor = obj
        else:
            tensor = getattr(obj, 'data', None)
            if not torch.is_tensor(tensor):
                return
        if not only_cuda or tensor.is_cuda:
            tensors[id(tensor)] = tensor

    for obj in gc.get_objects():
        # Arbitrary gc-tracked objects can raise almost anything on attribute
        # access (e.g. ReferenceError from a dead weakref proxy). The scan is
        # deliberately best effort, and the two phases are independent so a
        # failure in one does not skip the other.
        try:
            add_tensor(obj)
        except Exception:
            pass

        if id(obj) in omit_ids:
            continue
        try:
            # Some tensors are "saved & hidden" for the backward pass.
            if hasattr(obj, 'saved_tensors'):
                for saved in obj.saved_tensors:
                    add_tensor(saved)
        except Exception:
            # e.g. the saved buffers were already freed after backward().
            pass

    return list(tensors.values())