Hi, I defined a model (albeit somewhat naively) with many nested for-loops. The forward pass works fine for all batch sizes, but for batch sizes > 1 the backward pass explodes the memory footprint and the Python kernel crashes the system (I only have 8 GB on this particular machine). To give you an idea of the discrepancy: batch size 1 takes up less than 2% of total system memory (part of the network is not shown), but batch size 2 takes over 95% (it crashes before it gets higher). I was wondering if there is some mistake I'm making? I'm pretty sure the offending code is below:
def prim_to_uhat(self, x):
    u_hat = Variable(torch.FloatTensor(x.size(0), 32, 6, 6, 10, 16))
    for q in range(x.size(0)):
        for i in range(10):
            for j in range(32):
                for k in range(6):
                    for l in range(6):
                        u_hat[:, j, k, l, i, :] = self.wij[i][j][k][l](x[:, j, k, l, :])
    return u_hat
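For comparison, the four inner loops can be collapsed into a single tensor contraction by storing all the per-capsule weights in one parameter tensor. This is only a sketch under assumptions the post doesn't state: it assumes each `self.wij[i][j][k][l]` is a bias-free linear map from an 8-dim primary capsule to a 16-dim output (the 8 is inferred from typical CapsNet shapes, not from the snippet), and it uses `torch.einsum`, which requires a modern PyTorch rather than the old `Variable` API:

```python
import torch
import torch.nn as nn

class PrimToUHat(nn.Module):
    """Hypothetical vectorized version of prim_to_uhat (a sketch,
    assuming bias-free 8->16 linear maps per (i, j, k, l))."""

    def __init__(self, in_dim=8, out_dim=16):
        super().__init__()
        # One weight tensor stands in for the 10 * 32 * 6 * 6 separate layers.
        self.W = nn.Parameter(torch.randn(10, 32, 6, 6, out_dim, in_dim) * 0.01)

    def forward(self, x):
        # x: (batch, 32, 6, 6, in_dim) -> u_hat: (batch, 32, 6, 6, 10, out_dim)
        # Contract over the input-capsule dimension d for every (j, k, l, i) at once.
        return torch.einsum('ijklod,bjkld->bjklio', self.W, x)
```

A single contraction like this records one autograd node instead of one per loop iteration, which is usually where the loop version's graph (and memory) growth comes from.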