I have a tensor a:
a = torch.randn(3, 512)
and I have a function concat.
I want to compute concat(t1, t1), then concat(t1, t2), then concat(t1, t3), then concat(t2, t1), …, where t1, t2, t3 are the rows of a.
I can easily do this with a for loop, but that takes 24 hours for one epoch.
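For reference, a minimal sketch of the loop version. The concat here is only a stand-in (torch.cat along the last dimension), which is an assumption and not the real function; the pairs are taken over the rows of a.

```python
import torch

def concat(x, y):
    # Stand-in for the real concat (assumption): join two rows end to end.
    return torch.cat((x, y), dim=-1)

a = torch.randn(3, 512)

pairs = []
for i in range(a.size(0)):
    for j in range(a.size(0)):
        # Order: (row 0, row 0), (row 0, row 1), (row 0, row 2), (row 1, row 0), ...
        pairs.append(concat(a[i], a[j]))

out = torch.stack(pairs)  # shape: (3 * 3, 1024)
```

Every pair is built one at a time in Python, which is why this gets slow as the number of rows grows.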
The second method is the following:
n = a.size(0)
# t1 repeats the whole of a n times: row k is a[k % n]
t1 = a.unsqueeze(0).expand(n, -1, -1).contiguous().view(n * n, -1)
# t2 repeats each row of a n times: row k is a[k // n]
t2 = a.unsqueeze(1).expand(-1, n, -1).contiguous().view(n * n, -1)
concat(t1, t2)
But this runs out of memory.
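A rough estimate of why, as far as I can tell: expand alone does not copy data, but .contiguous() materializes the full result, so each of t1 and t2 ends up with n * n rows. The sketch below assumes float32 (4 bytes per element), and the larger row count is hypothetical, not my real one.

```python
def expanded_bytes(n, d, bytes_per_elem=4):
    # One expanded-and-materialized tensor has n * n rows of d elements.
    return n * n * d * bytes_per_elem

small = expanded_bytes(3, 512)        # the toy shape above
large = expanded_bytes(10_000, 512)   # hypothetical row count

print(small)        # 18432 (bytes)
print(large / 1e9)  # 204.8 (GB)
```

That is per tensor, before counting whatever concat allocates for its output.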
If there is an easier way, please let me know.