Hi,
I think this can help:
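import torch
# Three tensors, each with a leading dimension of size 1
list_of_tensors = [torch.randn(1, 2, 2), torch.randn(1, 2, 2), torch.randn(1, 2, 2)]
# Concatenate along dim 0 into one tensor of shape (3, 2, 2); no preallocation is needed
b = torch.cat(list_of_tensors, dim=0)
print(b.shape)
# Move to the GPU only if one is available
if torch.cuda.is_available():
    b = b.cuda()
print(b.device)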
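The snippet above uses torch.cat because each tensor in the list already has a leading dimension of size 1. If your tensors do not have that dimension (say each one is just 2 x 2), torch.stack is probably what you want instead. Here is a minimal sketch; the shapes and the variable names (tensors, stacked, concatenated) are only assumptions for illustration, not taken from your code:

```python
import torch

# Hypothetical inputs: three tensors of shape (2, 2), with no leading batch dimension.
tensors = [torch.randn(2, 2) for _ in range(3)]

# torch.stack inserts a new dimension at dim=0, giving shape (3, 2, 2).
stacked = torch.stack(tensors, dim=0)

# torch.cat joins along an existing dimension, giving shape (6, 2).
concatenated = torch.cat(tensors, dim=0)

# Pick the GPU if one is available; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
stacked = stacked.to(device)

print(stacked.shape, concatenated.shape, stacked.device)
```

Using .to(device) rather than .cuda() keeps the same code working on machines without a GPU.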
Please see this thread:
Best