GPU memory usage during backward pass for sparse parameters

Hi,

I’m building a sparse feedforward layer for a large input (each input node is only connected to 1 or 2 nodes of the next layer). I put the sparse weights into a torch.sparse_coo_tensor together with the coordinates of the sparse connections (edges) and use torch.sparse.mm in the forward function. The whole point of the sparse connections is of course to keep memory usage low. However, during loss.backward() the GPU memory usage goes up significantly, to a level I would expect for computing the gradients of the same layer fully connected (~10 GB).

#Input nodes: 147,931
#Output nodes: 14,670
#Sparse connections: 150,997 => estimated gradient memory usage: 603,988 bytes (~0.6 MB)
#Connections if fully connected: 2,170,148,000 => estimated gradient memory usage: 8,680,591,000 bytes (~8.7 GB)
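
In case it helps, this is the back-of-the-envelope calculation behind those estimates (assuming float32, i.e. 4 bytes per gradient value):

# Rough estimate of gradient storage, 4 bytes per float32 value
n_in, n_out = 147_931, 14_670
n_sparse = 150_997

sparse_grad_bytes = n_sparse * 4        # 603,988 bytes (~0.6 MB)
dense_grad_bytes = n_in * n_out * 4     # ~8,680,591,000 bytes (~8.7 GB)
print(sparse_grad_bytes, dense_grad_bytes)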

class SparseConnectionsLayer(t.nn.Module):
    def __init__(self, bioNet, name="SCLNN"):
        super(SparseConnectionsLayer, self).__init__()
        # edges contains the COO coordinates of the sparse connections,
        # size equals t.Size([#nodes1, #nodes2])
        self.edges, self.size = bioNet.getDataForCOO()
        self.weight = t.nn.Parameter(t.normal(0, 0.1, [self.edges.size()[1]]))
        self.bias = t.nn.Parameter(t.ones([self.size[1]]) * 0.1)

    def forward(self, x):
        adjWeight = t.sparse_coo_tensor(self.edges.to(self.weight.device), self.weight, self.size)
        adjWeight = adjWeight.t()
        o = t.sparse.mm(adjWeight, x.t()).t()

        return o + self.bias
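
A minimal way to see where the memory goes is to run a single forward/backward of just this layer and check the peak allocation with torch.cuda.max_memory_allocated (a sketch, not my real data: the batch size and the random input are placeholders, and bioNet is the same object passed to __init__ above):

import torch as t

device = t.device("cuda")
layer = SparseConnectionsLayer(bioNet).to(device)
x = t.randn(32, layer.size[0], device=device)   # [batch, #input nodes], placeholder data

t.cuda.reset_peak_memory_stats(device)
out = layer(x)
print("peak after forward :", t.cuda.max_memory_allocated(device) / 1e9, "GB")

out.sum().backward()
print("peak after backward:", t.cuda.max_memory_allocated(device) / 1e9, "GB")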

Training loop:

while e <= epochs + warmStart:
    errTot = 0
    for sample in loader:
        optimizer.zero_grad()
        x, y = sample
        x = t.tensor(x, dtype=t.float, device=device)
        y = t.tensor(y, dtype=t.float, device=device)
        yp = self.model(x)
        loss = lossfn(t.squeeze(yp.float()), y)
        norm = sum(p.pow(2.0).sum() for p in self.model.parameters())
        loss = loss + weight_decay * norm
        loss.backward()
        optimizer.step()
        t.cuda.empty_cache()
        errTot += loss.data.item()

    scheduler.step(errTot)

Config: torch 1.8.1, CUDA 11.4, Python 3.8.2

Any ideas on how to lower the memory usage?
Thanks!