GPU memory leak on torch.cat

I’ve read this discussion about torch.cat and a linear layer (which you answered):

```python
import torch
from torch import nn

# Three float32 tensors of shape (8, 200, 100, 200) on the GPU
x1 = torch.randn(8, 200, 100, 200, device="cuda")
x2 = torch.rand_like(x1)
x3 = torch.rand_like(x1)
print(torch.cuda.memory_allocated() / 1024 ** 3)
# 0.35762786865234375

# Concatenate along the last dimension -> shape (8, 200, 100, 600)
y = torch.cat((x1, x2, x3), dim=-1)
print(torch.cuda.memory_allocated() / 1024 ** 3)
# 0.7152557373046875

print(y.device)
print(y.shape)

# Linear layer mapping the 600 concatenated features to 1 output
temp = nn.Linear(600, 1)
temp.cuda(0)
print(y.device, x1.device, x2.device, x3.device)

k = temp(y)
print(k.shape)
print(torch.cuda.memory_allocated() / 1024 ** 3)
# 0.73
```
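As a sanity check, the first two printed values can be reproduced from the tensor shapes alone. This is a plain-Python sketch, assuming float32 tensors (4 bytes per element) and ignoring any rounding the CUDA caching allocator applies to allocation sizes:

```python
# Expected memory for the tensors in the snippet, assuming float32 (4 bytes/element).
BYTES_PER_FLOAT32 = 4
GIB = 1024 ** 3


def numel(shape):
    """Number of elements in a tensor with the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n


x_bytes = numel((8, 200, 100, 200)) * BYTES_PER_FLOAT32  # each of x1, x2, x3
y_bytes = numel((8, 200, 100, 600)) * BYTES_PER_FLOAT32  # result of torch.cat

after_inputs = 3 * x_bytes / GIB          # x1, x2, x3 allocated
after_cat = (3 * x_bytes + y_bytes) / GIB  # inputs still alive, plus y

print(after_inputs)  # 0.35762786865234375
print(after_cat)     # 0.7152557373046875
```

The increase after torch.cat is exactly the size of y (0.357 GiB): torch.cat writes its result into a newly allocated tensor, and x1, x2 and x3 are still referenced at that point, so they stay allocated too.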

I reused your code to check whether the linear layer is the cause, but it isn’t: almost all of the increase happens at the torch.cat call.
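The linear layer’s own footprint can be estimated the same way. A minimal sketch, again assuming float32 and ignoring allocator rounding: nn.Linear(600, 1) holds a 600-element weight and a 1-element bias, and its output k has shape (8, 200, 100, 1).

```python
# Footprint of nn.Linear(600, 1) applied to y, assuming float32 (4 bytes/element).
BYTES_PER_FLOAT32 = 4
GIB = 1024 ** 3

in_features, out_features = 600, 1

# Parameters: weight (out_features x in_features) plus bias (out_features)
param_bytes = (in_features * out_features + out_features) * BYTES_PER_FLOAT32

# Output k: same leading dims as y, last dim replaced by out_features
k_bytes = 8 * 200 * 100 * out_features * BYTES_PER_FLOAT32

linear_gib = (param_bytes + k_bytes) / GIB
print(f"{linear_gib:.6f} GiB")
```

That comes to roughly 0.0006 GiB, which is consistent with the linear layer not being responsible for the large jump.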

I feel like I’ve tried almost everything to solve this problem. :sob: