Running out of GPU memory while accumulating a sum

I am trying to train a graph neural network. The following snippet reproduces my issue.

import torch
from torch import nn

n = 14500
r = 1000
d = 50

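# sparse n x n adjacency matrix for one relation (COO format), moved to the GPU
indices = torch.tensor([[0, 1, 1],
                        [2, 0, 2]])
values = torch.tensor([3.0, 4.0, 5.0])
_A = torch.sparse_coo_tensor(indices, values, (n, n)).cuda()

A = [_A] * r                                  # r relation matrices A_k
X = torch.rand(n, n).cuda()                   # dense n x n features
W = nn.Parameter(torch.rand(r, n, d).cuda())  # move to GPU before wrapping so W stays a leaf Parameter

out = torch.zeros(n, d).cuda()
for i, a in enumerate(A):
    emb_acc = torch.sparse.mm(a, X)     # (n x n)
    out += torch.matmul(emb_acc, W[i])  # (n x d)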
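For scale, here is a rough float32 size estimate of the tensors involved (a sketch that ignores the CUDA context and caching-allocator overhead). Each `emb_acc` is a dense n x n matrix, and because `torch.matmul(emb_acc, W[i])` needs `emb_acc` to compute the gradient with respect to `W[i]`, autograd keeps one such matrix alive per iteration until `backward()` runs.

```python
# Rough float32 memory estimate for the tensors in the snippet (4 bytes per element).
n, r, d = 14500, 1000, 50
BYTES_PER_FLOAT32 = 4


def gib(numel: int) -> float:
    """Convert a number of float32 elements to GiB."""
    return numel * BYTES_PER_FLOAT32 / 2**30


x_gib = gib(n * n)        # dense X, n x n
w_gib = gib(r * n * d)    # all relation weights W, r x n x d
emb_gib = gib(n * n)      # one emb_acc = A_k @ X, n x n, saved by autograd each iteration

print(f"X:       {x_gib:.2f} GiB")
print(f"W:       {w_gib:.2f} GiB (its gradient needs the same again)")
print(f"emb_acc: {emb_gib:.2f} GiB per loop iteration")
```

Under this estimate the saved `emb_acc` matrices alone grow by roughly 0.78 GiB per iteration on top of `X` and `W`, which is consistent with a 12 GB card filling up after a handful of iterations.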

I want to compute $o = \sum_{k=1}^{r} A_k X W_k$ (the order in which the loop multiplies) and backpropagate with respect to $W$.
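Reading the loop as $o = \sum_k A_k X W_k$, associativity of matrix multiplication lets both the sum and its gradient be written without ever forming the dense $n \times n$ product $A_k X$. Here $G$ is notation introduced for this sketch: the upstream gradient $\partial L / \partial o$ of some scalar loss $L$.

```latex
\begin{aligned}
o &= \sum_{k=1}^{r} A_k X W_k = \sum_{k=1}^{r} A_k \,(X W_k),
  \qquad X W_k \in \mathbb{R}^{n \times d},\\
\frac{\partial L}{\partial W_k} &= (A_k X)^\top G = X^\top \left(A_k^\top G\right),
  \qquad G = \frac{\partial L}{\partial o} \in \mathbb{R}^{n \times d}.
\end{aligned}
```

Every intermediate on the right-hand sides is $n \times d$, and the products involving $A_k$ are sparse-times-dense.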

When I run this snippet on a GPU with about 12 GB of memory, it runs out of memory at around the 7th iteration of the for loop. Is there a more memory-efficient way to compute this?
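A minimal sketch of computing the sum as $\sum_k A_k (X W_k)$ instead, assuming `torch.sparse.mm` for the sparse-times-dense product; `relational_sum` is a hypothetical helper name. It is checked here on tiny CPU tensors against the original multiplication order and against the closed-form gradient; for the real sizes, the same function would be called with the GPU tensors from the question.

```python
import torch
from torch import nn


def relational_sum(A, X, W):
    """Compute sum_k A[k] @ X @ W[k] as sum_k A[k] @ (X @ W[k]).

    A: list of r sparse COO (n x n) tensors
    X: dense (n x n) tensor
    W: (r, n, d) tensor
    Only (n x d) intermediates are created, never an (n x n) one.
    """
    out = torch.zeros(X.shape[0], W.shape[2], device=X.device, dtype=X.dtype)
    for k, a in enumerate(A):
        xw = X @ W[k]                       # dense (n x n) @ (n x d) -> (n x d)
        out = out + torch.sparse.mm(a, xw)  # sparse (n x n) @ (n x d) -> (n x d)
    return out


# Tiny CPU check against the original order (A_k @ X) @ W_k.
torch.manual_seed(0)
n, r, d = 6, 3, 2
indices = torch.tensor([[0, 1, 1], [2, 0, 2]])
values = torch.tensor([3.0, 4.0, 5.0])
A = [torch.sparse_coo_tensor(indices, values, (n, n)) for _ in range(r)]
X = torch.rand(n, n)
W = nn.Parameter(torch.rand(r, n, d))

out = relational_sum(A, X, W)
reference = sum(torch.sparse.mm(a, X) @ W[k] for k, a in enumerate(A))
out.sum().backward()  # upstream gradient G is all ones
```

With this ordering the backward pass needs `X` (already resident), the sparse `A_k`, and at most n x d intermediates, so memory no longer grows by an n x n matrix per iteration. `W` itself (about 2.7 GiB in float32) and its equally sized gradient remain a fixed cost.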

Any help would be appreciated!