Hi everyone,
Let’s assume we have a tensor h_i of dimension F, a weight matrix W of dimension FxF, and a function a that maps a pair of F-dimensional vectors to a scalar; I represent a as a vector of shape 2Fx1.
I would like to compute e_ij = ELU(a^T [h_i W || h_j W]) for every pair of rows (i, j), where || denotes concatenation.
Currently, I’m doing the following:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
    def __init__(self, in_features, out_features):
        super(MyModule, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Keep the parameters on the GPU when one is available.
        dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
        self.W = nn.Parameter(
            nn.init.xavier_uniform(torch.Tensor(in_features, out_features).type(dtype),
                                   gain=np.sqrt(2.0)),
            requires_grad=True)
        self.a = nn.Parameter(
            nn.init.xavier_uniform(torch.Tensor(2 * out_features, 1).type(dtype),
                                   gain=np.sqrt(2.0)),
            requires_grad=True)

    def forward(self, input):
        h = torch.mm(input, self.W)  # (N, out_features)
        N = h.size()[0]
        # Build all N*N concatenations [h_i || h_j] -- this is the memory hog.
        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)],
                            dim=1).view(N, -1, 2 * self.out_features)
        e = F.elu(torch.matmul(a_input, self.a).squeeze(2))  # (N, N)
        return e
This is working, but as you can imagine it takes ~7 GB of memory on my GPU. If I want to run multiple instances on my GPU, I can’t, due to the memory usage.
Do you see a way to do this operation more efficiently? I really think that using repeat wastes a lot of memory, since it materializes an N x N x 2F tensor. Maybe there is a clever way to handle this?
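One direction I considered, sketched below with made-up sizes: since a^T [h_i || h_j] = a1^T h_i + a2^T h_j (where a1 and a2 are the two halves of a), the full N x N score matrix should be expressible as a broadcast sum of two N x 1 vectors, with no repeat/cat at all. I’m not sure this is the intended idiom, but it seems to match the original computation:

```python
import torch
import torch.nn.functional as F

N, out_features = 5, 8                 # placeholder sizes
h = torch.randn(N, out_features)       # h = input @ W, as in forward()
a = torch.randn(2 * out_features, 1)   # same attention vector as self.a

# a^T [h_i || h_j] = a1^T h_i + a2^T h_j, so the N x N score matrix
# is an outer (broadcast) sum of two N x 1 vectors -- O(N*F) memory.
a1, a2 = a[:out_features], a[out_features:]
e = F.elu(h @ a1 + (h @ a2).t())       # (N,1) + (1,N) -> (N,N)

# sanity check against the original repeat-based computation
a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)],
                    dim=1).view(N, -1, 2 * out_features)
e_ref = F.elu(torch.matmul(a_input, a).squeeze(2))
assert torch.allclose(e, e_ref, atol=1e-5)
```

If that equivalence holds in general, the intermediate memory drops from O(N^2 * F) to O(N^2 + N * F).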
Thank you for your answers !