GPU: clear variable after use

You can try this snippet:

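import gc
import torch

# Drop your own references first (e.g. `del my_tensor`, a placeholder name);
# otherwise neither call below can free that tensor's memory.
gc.collect()              # collect unreachable Python objects, including tensors stuck in reference cycles
torch.cuda.empty_cache()  # release PyTorch's unused cached GPU memory back to the driver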
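
To check that the memory is really released, you can compare `torch.cuda.memory_allocated()` before and after the cleanup. This is only a rough sketch: `free_gpu_memory` and `demo` are names made up for illustration, and the tensor size is arbitrary. On a machine without CUDA it falls back to the CPU and reports zeros.

```python
import gc
import torch

def free_gpu_memory():
    """Collect garbage, then release PyTorch's unused cached GPU memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def demo():
    # Hypothetical example: allocate a tensor, drop it, and measure allocated bytes.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.zeros(1024, 1024, device=device)
    before = torch.cuda.memory_allocated() if device == "cuda" else 0
    del x  # the reference must go away before anything can be freed
    free_gpu_memory()
    after = torch.cuda.memory_allocated() if device == "cuda" else 0
    return before, after

before, after = demo()
print(f"allocated before: {before} bytes, after: {after} bytes")
```

Note that `empty_cache()` only returns cached blocks that no live tensor is using, so memory still referenced somewhere (a list, a closure, a notebook output cell) stays allocated.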

Also read: