My network gets through the first epoch on the training set without any problem, but when it then tries to calculate the validation or test accuracy, I get an ‘out of memory’ error.
I made a modification to the LSTMCell function in nn/_functions/rnn.py, with the aim of doing a little clustering on the activations within a single node based on each node’s other activations within a batch.
def LSTMCell(input, hidden, w_ih, w_hh, count, b_ih=None, b_hh=None):
    hx, cx = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
    ingate, forgetgate, cellgate, outgate, sensitivitygate = gates.chunk(5, 1)
    batch_size = cellgate.size(0)
    a = F.relu(sensitivitygate)
    a = torch.unsqueeze(a, 2)
    a = a.expand(-1, -1, 3*batch_size)
    ingate = F.sigmoid(ingate)
    forgetgate = F.sigmoid(forgetgate)
    cellgate = F.tanh(cellgate)
    outgate = F.sigmoid(outgate)
    cy = (forgetgate * cx) + (ingate * cellgate)
    position = F.tanh(cy)
    position = torch.cat((position, hx, cx), 0)
    position = torch.unsqueeze(position, 2)
    position = position.expand(-1, -1, 3*batch_size)
    distance = (position - torch.transpose(position, dim0=0, dim1=2))**2
    distance = 1/(1 + a*distance[0:batch_size, :, :])
    weighted_avg = torch.sum((distance*position[0:batch_size, :, :]), 2)
    weighted_avg = weighted_avg/torch.sum(distance, 2)
    hy = outgate * weighted_avg
    return hy, cy
The error occurs on this line:
distance = (position - torch.transpose(position,dim0=0,dim1=2))**2
In this line I subtract a transposed copy of the tensor from the original (I swap the batch dimension with a third dimension that I added, to stretch the square into a cube). This gives the squared distance between each activation and every other activation in that batch, for that neuron.
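To make the shapes concrete, here is a minimal standalone sketch of what I believe that line computes. The helper names `pairwise_sq_dist` and `pairwise_sq_dist_expand` are made up for this illustration (they are not part of PyTorch), and the tensor sizes are arbitrary. The first version relies on broadcasting, the second mirrors the expand-and-transpose approach in my LSTMCell.

```python
import torch


def pairwise_sq_dist(p):
    """Squared difference between every pair of rows of p, per column.

    p has shape (N, H); the result has shape (N, H, N) with
    result[i, h, j] == (p[i, h] - p[j, h]) ** 2.
    """
    # (N, H, 1) minus (1, H, N) broadcasts to (N, H, N).
    diff = p.unsqueeze(2) - p.t().unsqueeze(0)
    return diff ** 2


def pairwise_sq_dist_expand(p):
    """Same result, written the way the modified LSTMCell does it."""
    n = p.size(0)
    cube = p.unsqueeze(2).expand(-1, -1, n)  # (N, H, N) view of p
    return (cube - torch.transpose(cube, dim0=0, dim1=2)) ** 2


torch.manual_seed(0)
p = torch.randn(6, 4)  # hypothetical: 3 * batch_size = 6 rows, 4 hidden units
d = pairwise_sq_dist(p)
print(d.shape)
```

Either way, the result is a full (N, H, N) tensor, where N is 3*batch_size in my case.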
I can imagine that this line is very memory-intensive, but the network seems to be able to manage it for the training data.
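As a rough back-of-the-envelope check on how demanding it is, this sketch estimates the size of the distance tensor alone. It assumes float32 activations (4 bytes each); `distance_tensor_bytes` is a hypothetical helper and the batch and hidden sizes are example values, not my actual settings. It ignores every other tensor and anything kept around for the backward pass.

```python
def distance_tensor_bytes(batch_size, hidden_size, bytes_per_element=4):
    """Bytes needed for one (3*batch, hidden, 3*batch) tensor.

    Assumes float32 (4 bytes per element) unless told otherwise.
    """
    n = 3 * batch_size  # position stacks tanh(cy), hx and cx along dim 0
    return n * hidden_size * n * bytes_per_element


# Example values only: batch of 64, 128 hidden units.
small = distance_tensor_bytes(64, 128)
large = distance_tensor_bytes(128, 128)
print(small / 2**20, "MiB per time step")
print(large / small, "times more memory when the batch size doubles")
```

Since the size grows with the square of the batch size, a larger batch during validation or testing than during training would be one possible explanation, though I have not confirmed that this is what is happening here.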
Specific Error:
RuntimeError: cuda runtime error (2) : out of memory at c:\anaconda2\conda-bld\pytorch_1519496000060\work\torch\lib\thc\generic/THCStorage.cu:58
Please help me figure out why I’m getting memory issues when the network gets to the validation set, and how I can resolve this!
It would be enormously appreciated!
Thanks!!