Hi,
I’m building a sparse feedforward layer for a large input (each input node is connected to only 1 or 2 nodes of the next layer). I pass the sparse weights, together with the coordinates of the sparse connections (edges), into a torch.sparse_coo_tensor and use torch.sparse.mm in the forward function. The whole point of the sparse connections is, of course, to keep memory usage low. However, during loss.backward() the GPU memory usage rises to about 10 GB, which is roughly what I would expect for computing the gradients of the same layer if it were fully connected.
#Input nodes: 147,931
#Output nodes: 14,670
#Sparse connections: 150,997 => estimated gradient memory usage (float32, 4 bytes each): 603,988 bytes
#Connections if fully connected: 2,170,147,770 (147,931 × 14,670) => estimated gradient memory usage: 8,680,591,080 bytes (≈ 8.7 GB)
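Both estimates are just the number of float32 values times 4 bytes. A quick sketch of that arithmetic, assuming float32 weights and ignoring the index tensor, activations and optimizer state:

```python
# Back-of-the-envelope gradient memory estimate, assuming float32 (4 bytes per value).
BYTES_PER_FLOAT32 = 4

n_in = 147_931   # input nodes
n_out = 14_670   # output nodes
nnz = 150_997    # sparse connections (edges)

# One gradient value per existing edge vs. one per possible input/output pair.
sparse_grad_bytes = nnz * BYTES_PER_FLOAT32
dense_grad_bytes = n_in * n_out * BYTES_PER_FLOAT32

print(f"sparse: {sparse_grad_bytes:,} bytes")
print(f"dense:  {dense_grad_bytes:,} bytes (~{dense_grad_bytes / 1e9:.1f} GB)")
```

The observed ~10 GB is therefore in the range of the dense estimate, not the sparse one.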
import torch as t

class SparseConnectionsLayer(t.nn.Module):
    def __init__(self, bioNet, name="SCLNN"):
        super(SparseConnectionsLayer, self).__init__()
        # edges: 2 x nnz coordinates of the sparse connections; size: t.Size([#nodes1, #nodes2])
        self.edges, self.size = bioNet.getDataForCOO()
        # one trainable weight per connection
        self.weight = t.nn.Parameter(t.normal(0, 0.1, [self.edges.size()[1]]))
        self.bias = t.nn.Parameter(t.ones([self.size[1]]) * 0.1)

    def forward(self, x):
        # rebuild the sparse (#nodes1 x #nodes2) weight matrix from the current weights;
        # it is created on the weight's device because indices and values live there
        adjWeight = t.sparse_coo_tensor(self.edges.to(self.weight.device), self.weight, self.size)
        adjWeight = adjWeight.t()
        # (#nodes2 x #nodes1) @ (#nodes1 x batch) -> (#nodes2 x batch), transposed to (batch x #nodes2)
        o = t.sparse.mm(adjWeight, x.t()).t()
        return o + self.bias
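One possible workaround (a swapped-in technique, not what the layer above does) is to avoid torch.sparse entirely: gather the input value for each edge, multiply it by that edge's weight, and scatter-add the products into the output columns with index_add. Autograd then only saves a (batch × nnz) intermediate, never anything of size #nodes1 × #nodes2. A minimal sketch, assuming `edges` is a 2 × nnz LongTensor of (input index, output index) pairs like the one returned by getDataForCOO(); the class name `GatherScatterLayer` is hypothetical, and this has not been benchmarked against the torch.sparse.mm version on torch 1.8.1.

```python
import torch
import torch.nn as nn


class GatherScatterLayer(nn.Module):
    """Hypothetical sparse linear layer: y[:, j] = sum of w_e * x[:, i] over edges e = (i, j), plus b[j]."""

    def __init__(self, edges: torch.Tensor, n_in: int, n_out: int):
        super().__init__()
        # edges: shape (2, nnz); row 0 = input index, row 1 = output index.
        # Registered as buffers so .to(device) moves them together with the parameters.
        self.register_buffer("src", edges[0].long().clone())
        self.register_buffer("dst", edges[1].long().clone())
        self.n_in, self.n_out = n_in, n_out
        self.weight = nn.Parameter(torch.normal(0.0, 0.1, (edges.size(1),)))
        self.bias = nn.Parameter(torch.full((n_out,), 0.1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_in). Gather the input feeding each edge and weight it: (batch, nnz).
        contrib = x[:, self.src] * self.weight
        # Out-of-place scatter-add of each edge's contribution into its output column: (batch, n_out).
        out = x.new_zeros((x.size(0), self.n_out)).index_add(1, self.dst, contrib)
        return out + self.bias
```

Usage on a toy graph, checked against an equivalent dense weight matrix:

```python
edges = torch.tensor([[0, 1, 2, 2], [0, 0, 1, 2]])
layer = GatherScatterLayer(edges, n_in=3, n_out=3)
x = torch.randn(5, 3)
y = layer(x)          # shape (5, 3)
y.sum().backward()    # layer.weight.grad has one entry per edge
```

The gradient for each weight is the batch sum of its gathered input times the upstream gradient at its output column, so it stays nnz-sized.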
Training loop:
while e <= epochs + warmStart:
    errTot = 0
    for sample in loader:
        optimizer.zero_grad()
        x, y = sample
        # as_tensor avoids the copy warning t.tensor() raises when x/y are already tensors
        x = t.as_tensor(x, dtype=t.float, device=device)
        y = t.as_tensor(y, dtype=t.float, device=device)
        yp = self.model(x)
        loss = lossfn(t.squeeze(yp.float()), y)
        # explicit L2 penalty over all parameters
        norm = sum(p.pow(2.0).sum() for p in self.model.parameters())
        loss = loss + weight_decay * norm
        loss.backward()
        optimizer.step()
        t.cuda.empty_cache()
        errTot += loss.item()
    scheduler.step(errTot)
    e += 1
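To check whether the jump really comes from this layer's backward pass (rather than, for example, optimizer state, or the caching allocator keeping freed blocks reserved, which is what nvidia-smi shows), it may help to measure peak *allocated* memory for one forward/backward of the layer in isolation. A minimal sketch using torch.cuda.reset_peak_memory_stats and torch.cuda.max_memory_allocated; the helper name `peak_backward_memory` is hypothetical.

```python
import torch


def peak_backward_memory(layer: torch.nn.Module, x: torch.Tensor):
    """Peak CUDA bytes allocated during one forward+backward of `layer`, or None without a GPU."""
    if not torch.cuda.is_available():
        return None
    device = torch.device("cuda")
    layer = layer.to(device)
    x = x.to(device)
    torch.cuda.synchronize(device)
    # Reset the high-water mark so only this forward/backward is measured.
    torch.cuda.reset_peak_memory_stats(device)
    out = layer(x)
    out.sum().backward()
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device)
```

Calling it with the SparseConnectionsLayer above and a representative batch gives a number that can be compared directly with the ~8.7 GB dense estimate.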
config: torch 1.8.1, CUDA 11.4, Python 3.8.2
Any ideas on how to lower the memory usage?
Thanks!