I’ve read this discussion about torch.cat followed by a linear layer (the one you answered in).
import torch
from torch import nn
# three (8, 200, 100, 200) float32 tensors, ~122 MiB each
x1 = torch.randn(8, 200, 100, 200, device="cuda")
x2 = torch.rand_like(x1)
x3 = torch.rand_like(x1)
print(torch.cuda.memory_allocated() / 1024 ** 3)
# 0.35762786865234375
# cat allocates a new (8, 200, 100, 600) tensor and copies all three
# inputs into it, so the allocated memory doubles here
y = torch.cat((x1, x2, x3), dim=-1)
print(torch.cuda.memory_allocated() / 1024 ** 3)
# 0.7152557373046875
print(y.device)
print(y.shape)
temp = nn.Linear(600, 1)
temp.cuda(0)  # moves the layer's parameters to GPU 0 in place
print(y.device, x1.device, x2.device, x3.device)  # everything on the same device
k = temp(y)
print(k.shape)  # (8, 200, 100, 1), so the output itself is tiny
print(torch.cuda.memory_allocated() / 1024 ** 3)
# 0.73
I reused your code to check whether the linear layer was the cause, but it wasn’t.
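To rule out a transient spike that memory_allocated() wouldn’t show (it only reports memory currently held by live tensors), the repro can be extended with peak-memory readings. A minimal sketch of that check (the printed numbers will vary by GPU and PyTorch version):

import torch
from torch import nn

x1 = torch.randn(8, 200, 100, 200, device="cuda")
x2 = torch.rand_like(x1)
x3 = torch.rand_like(x1)

# reset the peak counter so the readings below cover only the cat + linear
torch.cuda.reset_peak_memory_stats()

y = torch.cat((x1, x2, x3), dim=-1)
temp = nn.Linear(600, 1).cuda(0)
k = temp(y)

print(torch.cuda.memory_allocated() / 1024 ** 3)      # currently held by tensors
print(torch.cuda.max_memory_allocated() / 1024 ** 3)  # peak since the reset above
print(torch.cuda.memory_reserved() / 1024 ** 3)       # cached by the allocator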
I feel like I’ve tried almost everything to solve this problem.
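For completeness: since a linear layer applied to a concatenation equals the sum of linears applied to each chunk with the corresponding slice of the weight matrix, the extra copy made by cat can be avoided entirely. A sketch of that equivalence, using the same shapes as above:

import torch
from torch import nn
import torch.nn.functional as F

x1 = torch.randn(8, 200, 100, 200, device="cuda")
x2 = torch.rand_like(x1)
x3 = torch.rand_like(x1)
full = nn.Linear(600, 1).cuda(0)

# full(cat(x1, x2, x3)) == x1 @ W1.T + x2 @ W2.T + x3 @ W3.T + b,
# where W1, W2, W3 are 200-column slices of full.weight
w1, w2, w3 = full.weight.split(200, dim=1)
k = F.linear(x1, w1) + F.linear(x2, w2) + F.linear(x3, w3, full.bias)

ref = full(torch.cat((x1, x2, x3), dim=-1))
print(torch.allclose(k, ref, atol=1e-5))  # True; k never materializes the cat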