I am trying to train a graph neural net. The following snippet reproduces my issue.
import torch
from torch import nn
n = 14500
r = 1000
d = 50
i = torch.LongTensor([[0, 1, 1],
                      [2, 0, 2]])
v = torch.FloatTensor([3, 4, 5])
_A = torch.sparse_coo_tensor(i, v, (n, n)).cuda()  # torch.sparse.FloatTensor is deprecated
A = [_A] * r  # r copies of the same sparse matrix, just for this repro
X = torch.rand(n, n).cuda()
W = nn.Parameter(torch.rand(r, n, d).cuda())  # wrap the CUDA tensor so W stays a leaf parameter
out = torch.zeros(n, d).cuda()
for idx, a in enumerate(A):  # renamed from i, which shadowed the index tensor above
    print(idx)
    emb_acc = torch.sparse.mm(a, X)       # dense (n x n) intermediate
    out += torch.matmul(emb_acc, W[idx])  # (n x n) @ (n x d) -> (n x d)
I want to be able to compute $o = \sum_{r} A_r X W_r$ (the per-relation product in the loop above) and backprop w.r.t. $W$.
Running this snippet on a GPU with about 12 GB of memory, the memory fills up at around the 7th iteration of the loop. Is there a more memory-efficient way to compute this sum?
Any help would be appreciated!
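In case it helps to see the idea concretely: since matrix multiplication is associative, $A_r X W_r$ can be computed as $A_r (X W_r)$. Then the intermediate `X @ W[idx]` is only $(n \times d)$, and `torch.sparse.mm` produces an $(n \times d)$ result directly, so no dense $(n \times n)$ tensor is ever created or saved by autograd for the backward pass. Below is a minimal CPU sketch of that reassociation; the dimensions are shrunk so it runs anywhere, and the repeated-`_A` adjacency list is kept only to mirror the repro above.

```python
import torch
from torch import nn

# Small dimensions so the sketch runs anywhere;
# in the question n = 14500, r = 1000, d = 50.
n, r, d = 6, 3, 4

i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3.0, 4.0, 5.0])
A_sp = torch.sparse_coo_tensor(i, v, (n, n))
A = [A_sp] * r  # r relations; all equal here only for the repro

X = torch.rand(n, n)
W = nn.Parameter(torch.rand(r, n, d))

out = torch.zeros(n, d)
for idx, a in enumerate(A):
    # X @ W[idx] is (n x d), so sparse.mm also returns (n x d);
    # the dense (n x n) intermediate from sparse.mm(a, X) never exists
    out = out + torch.sparse.mm(a, X @ W[idx])

# gradients still flow to W through the sparse matmul
out.sum().backward()
```

The memory saving on the backward side is the main point: in the original ordering, autograd must save each dense `emb_acc` (roughly $n^2 \cdot 4$ bytes, about 840 MB per iteration at $n = 14500$) to compute the gradient of `W[idx]`, which matches running out of 12 GB around the 7th iteration. In the reassociated form, the saved tensors are the shared `X` and the shared sparse `a`, so peak memory stays roughly constant across iterations.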