Any idea how to make this operation use less memory?

Hi everyone,

Let’s assume we have a tensor h_i of dimension F, a weight matrix W of dimension F×F, and a function a: R^F × R^F → R that maps a pair of F-dimensional vectors to a scalar. I represent a as a 2F×1 vector applied to the concatenation of its two inputs.

I would like to compute e_ij = ELU(a^T [W h_i || W h_j]) for every pair (i, j), where || denotes concatenation.
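Written out (a reconstruction read off the forward pass of the module, not necessarily the exact formula from the original post), and splitting the 2F×1 vector a into two halves a_1 and a_2 of length F, the score for a pair (i, j) separates into one term that depends only on i and one that depends only on j:

```latex
\begin{aligned}
e_{ij} &= \mathrm{ELU}\left(a^\top [\, W h_i \,\|\, W h_j \,]\right),
  \qquad a = \begin{bmatrix} a_1 \\ a_2 \end{bmatrix},\; a_1, a_2 \in \mathbb{R}^{F} \\
a^\top [\, W h_i \,\|\, W h_j \,] &= a_1^\top W h_i + a_2^\top W h_j \\
\Rightarrow\; e_{ij} &= \mathrm{ELU}(s_i + t_j), \qquad s_i = a_1^\top W h_i,\; t_j = a_2^\top W h_j
\end{aligned}
```

So the N×N score matrix is an outer sum of two length-N vectors, and in principle it never requires building the concatenated pairs explicitly.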

Currently, I’m doing this:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # W projects each input row; a scores a concatenated pair [h_i || h_j]
        self.W = nn.Parameter(torch.empty(in_features, out_features, device=device))
        self.a = nn.Parameter(torch.empty(2 * out_features, 1, device=device))
        nn.init.xavier_uniform_(self.W, gain=np.sqrt(2.0))
        nn.init.xavier_uniform_(self.a, gain=np.sqrt(2.0))

    def forward(self, input):
        h = torch.mm(input, self.W)  # (N, out_features)
        N = h.size(0)

        # Row i*N + j of the concatenation is [h_i || h_j]; reshape to (N, N, 2*out_features)
        a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
        # (N, N, 2*out_features) @ (2*out_features, 1) -> (N, N, 1) -> (N, N)
        e = F.elu(torch.matmul(a_input, self.a).squeeze(2))
        return e

This works, but as you can imagine it takes ~7 GB of GPU memory, so I can’t run multiple instances on the same GPU.
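To put a rough number on where the memory goes: a_input alone holds N·N·2F' floats (F' = out_features), so it grows quadratically with the batch size N. Below is a small sketch that estimates that size and checks the formula against a tensor built the same way as in forward. The function names and the sizes in the example are hypothetical, not taken from the original setup; temporaries from repeat and cat, and anything kept for the backward pass, come on top, so treat the estimate as a lower bound.

```python
import torch

def a_input_bytes(n, out_features, bytes_per_elem=4):
    """Size of the (N, N, 2*out_features) a_input tensor alone.

    Assumes float32 (4 bytes per element). Temporaries from repeat/cat and
    tensors saved for backward are NOT counted.
    """
    return n * n * 2 * out_features * bytes_per_elem

def measured_a_input_bytes(n, out_features):
    # Rebuild a_input the same way as MyModule.forward, on a random h
    h = torch.randn(n, out_features)
    a_input = torch.cat([h.repeat(1, n).view(n * n, -1), h.repeat(n, 1)], dim=1).view(n, -1, 2 * out_features)
    return a_input.numel() * a_input.element_size()

# Hypothetical sizes, for illustration only
print(a_input_bytes(2048, 64) / 2**30, "GiB")  # 2.0 GiB
```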

Do you see a way to do this operation more efficiently? I really think that using repeat wastes a lot of memory, since it materializes an N×N×2F tensor. Maybe there is a clever way to handle this?
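One possible way around repeat, sketched under the assumption that the ELU of a linear score in forward is all that is needed: since a acts linearly on the concatenation, split it into its two halves, score each row of h once, and let broadcasting build the N×N matrix. Both functions below are hypothetical helpers that take h after the projection by W; the first reproduces the current approach so the two can be compared.

```python
import torch
import torch.nn.functional as F

def pairwise_scores_concat(h, a):
    # Current approach: materializes an (N, N, 2F') tensor of concatenated pairs
    n, f = h.shape
    a_input = torch.cat([h.repeat(1, n).view(n * n, -1), h.repeat(n, 1)], dim=1).view(n, -1, 2 * f)
    return F.elu(torch.matmul(a_input, a).squeeze(2))

def pairwise_scores_broadcast(h, a):
    # a^T [h_i || h_j] = a_1^T h_i + a_2^T h_j, so each row is scored only once
    f = h.shape[1]
    s_i = h @ a[:f]  # (N, 1): contribution of the first element of the pair
    s_j = h @ a[f:]  # (N, 1): contribution of the second element of the pair
    # (N, 1) + (1, N) broadcasts to (N, N); no (N, N, 2F') tensor is created
    return F.elu(s_i + s_j.t())

torch.manual_seed(0)
h = torch.randn(50, 8)  # N = 50 projected rows, F' = 8
a = torch.randn(16, 1)  # 2F' x 1
print(torch.allclose(pairwise_scores_concat(h, a), pairwise_scores_broadcast(h, a), atol=1e-5))  # True
```

Inside MyModule.forward this would amount to computing h @ self.a[:self.out_features] and h @ self.a[self.out_features:] and adding them with broadcasting. The N×N output itself is still quadratic in N, so this shrinks the largest intermediate by roughly a factor of 2F' rather than removing the quadratic term.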

Thank you for your answers !

Could you explain a bit more about h_j and what it is supposed to be?
Somehow I don’t understand how a_input is built.

Also, if I run this code snippet, the output has a size of [batch_size, 10]. Is this the right behavior, given out_features=2?

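import torch

in_features = 5
out_features = 2
model = MyModule(in_features, out_features)

# Variable is no longer needed: plain tensors work directly with autograd
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(10, in_features, device=device)
output = model(x)
print(output.size())  # torch.Size([10, 10])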

Yes sure.

In this case, the h tensors are the last hidden states of an RNN.

I built a_input as follows (here is a small example rather than a long explanation):

In [8]: import torch

In [9]: import numpy as np

In [10]: h = torch.LongTensor(np.array([[1,1], [2,2], [3,3]]))

In [11]: N=3

In [12]: h.repeat(1, N).view(N * N, -1)
Out[12]: 

    1     1
    1     1
    1     1
    2     2
    2     2
    2     2
    3     3
    3     3
    3     3
[torch.LongTensor of size 9x2]

In [13]: h.repeat(N, 1)
Out[13]: 

    1     1
    2     2
    3     3
    1     1
    2     2
    3     3
    1     1
    2     2
    3     3
[torch.LongTensor of size 9x2]
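The pattern in the two outputs above: row k of the first tensor is h[k // N] and row k of the second is h[k % N], so after concatenating and reshaping to (N, N, 2F), entry [i, j] holds [h_i || h_j]. A small check of that reading on the same toy h, written with torch.tensor instead of LongTensor(np.array(...)):

```python
import torch

n = 3
h = torch.tensor([[1, 1], [2, 2], [3, 3]])

left = h.repeat(1, n).view(n * n, -1)  # row k holds h[k // n]
right = h.repeat(n, 1)                 # row k holds h[k % n]

for k in range(n * n):
    i, j = divmod(k, n)
    assert torch.equal(left[k], h[i]) and torch.equal(right[k], h[j])

# After reshaping to (n, n, 2F), entry [i, j] is the concatenation [h_i || h_j]
pairs = torch.cat([left, right], dim=1).view(n, n, -1)
print(pairs[0, 2])  # tensor([1, 1, 3, 3])
```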

Is it clearer now?

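Yes, the shape you describe is the intended behavior. In my snippet, the module receives a batch of hidden representations (one row per sample) and computes an attention score between every pair of samples in the batch, so N inputs give an N×N output regardless of out_features.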
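For reference, a slow but explicit version that spells out what attention among the whole batch means here: one score per ordered pair of rows, hence an N×N output for N inputs (10×10 in the shape check above). The function name is hypothetical; W and a play the roles of self.W and self.a. It is only meant to pin down the semantics, not to be used in training.

```python
import torch
import torch.nn.functional as F

def scores_loop(x, W, a):
    """Reference: e[i, j] = ELU(a^T [W h_i || W h_j]) for every pair in the batch.

    O(N^2) Python loop; never builds the (N, N, 2F') tensor.
    """
    h = x @ W  # (N, F'), row i is the projected sample i
    n = h.shape[0]
    e = torch.empty(n, n, dtype=h.dtype, device=h.device)
    for i in range(n):
        for j in range(n):
            # Score of the ordered pair (i, j) from the concatenated projections
            e[i, j] = torch.dot(torch.cat([h[i], h[j]]), a[:, 0])
    return F.elu(e)

x = torch.randn(10, 5)  # batch of 10 samples, in_features = 5
W = torch.randn(5, 2)   # out_features = 2
a = torch.randn(4, 1)   # 2 * out_features x 1
print(scores_loop(x, W, a).shape)  # torch.Size([10, 10])
```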

Thank you for your help !

PS: Here is the original paper: https://arxiv.org/pdf/1710.10903.pdf
PS2: My PyTorch implementation of this paper is available at https://github.com/Diego999/pyGAT and produces the expected results.

Does anyone have an idea, or another way to implement this operator?