Dear all,
I cannot figure out how to get rid of this out-of-memory error:
RuntimeError: CUDA out of memory. Tried to allocate 7.50 MiB (GPU 0; 11.93 GiB total capacity; 5.47 GiB already allocated; 4.88 MiB free; 81.67 MiB cached).
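To see where the memory goes, I print PyTorch's built-in allocator counters after every batch (nothing custom; note that memory_reserved() is called memory_cached() on older versions):

import torch

def log_gpu_memory(tag):
    # Tensors currently allocated vs. memory held by the caching allocator.
    allocated = torch.cuda.memory_allocated() / 1024 ** 2
    reserved = torch.cuda.memory_reserved() / 1024 ** 2
    print('{}: {:.1f} MiB allocated, {:.1f} MiB reserved'.format(tag, allocated, reserved))

If the allocated number grows with every batch, old graphs are presumably being kept alive somewhere.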
Because of the recurrent architecture of my network, I have to call backward with retain_graph=True; otherwise I get this error:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
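For context, here is a toy example (nothing to do with my actual model) of why the second backward() fails when a hidden state is carried across iterations without being detached:

import torch

w = torch.randn(3, 3, requires_grad=True)
h = torch.zeros(3)

for step in range(2):
    h = torch.tanh(w @ h + 1.0)  # h still references the previous iteration's graph
    loss = h.sum()
    loss.backward()  # fails on the second iteration: the first graph's buffers are freed
    # h = h.detach() here would cut the history and avoid the error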
Here is my main training loop:
for epoch in range(300):  # again, normally you would NOT do 300 epochs, it is toy data
    states = None  # torch.empty().to(device)
    for idx, image in enumerate(loader):
        # Step 1. Remember that Pytorch accumulates gradients.
        # We need to clear them out before each instance
        # Step 3. Run our forward pass.
        tensor = image[0].clone().to(device)
        if states is None:
            states = prednet.get_initial_states(tensor)
        prednet.zero_grad()
        # tensor = tensor.reshape(tensor.shape[0], 1, tensor.shape[1], tensor.shape[2], tensor.shape[3])
        tag_scores, states = prednet(tensor, states)
        # Step 4. Compute the loss, gradients, and update the parameters
        # by calling optimizer.step()
        loss = loss_function(tag_scores, torch.zeros_like(tag_scores))
        print(loss)
        loss.backward(retain_graph=True)
        for state in states:
            state.detach()
        optimizer.step()
        print('1 backward')
        torch.cuda.empty_cache()
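If I understand truncated backpropagation through time correctly, the detach loop above is supposed to cut the graph between batches, but Tensor.detach() is out-of-place, so I suspect the states need to be rebound, something like this (a sketch of what I think the inner loop should do; I am not sure it is right):

tag_scores, states = prednet(tensor, states)
loss = loss_function(tag_scores, torch.zeros_like(tag_scores))

optimizer.zero_grad()
loss.backward()  # no retain_graph: each batch backpropagates through its own graph only
optimizer.step()

# Rebind: detach() returns a new tensor instead of modifying in place.
states = [s.detach() for s in states]

With the states cut like this, each iteration's graph could be freed after backward(), so as far as I understand retain_graph=True would no longer be needed.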
Here is the forward function:
def forward(self, a, states=None):
    r_tm1 = states[:self.nb_layers]
    c_tm1 = states[self.nb_layers:2*self.nb_layers]
    e_tm1 = states[2*self.nb_layers:3*self.nb_layers]

    if self.extrap_start_time is not None:
        t = states[-1].clone()
        # if past self.extrap_start_time, the previous prediction is treated as the actual input
        a = torch.where(t >= self.t_extrap, states[-2], a)
    c = []
    r = []
    e = []
    for l in reversed(range(self.nb_layers)):
        inputs = [r_tm1[l], e_tm1[l]]
        if l < self.nb_layers - 1:
            inputs.append(r_up)
        inputs = torch.cat(inputs, self.channel_axis)
        i = self.conv_layers['i'][l](inputs)
        f = self.conv_layers['f'][l](inputs)
        o = self.conv_layers['o'][l](inputs)
        _c = f * c_tm1[l] + i * self.conv_layers['c'][l](inputs)
        _r = o * self.LSTM_activation(_c)
        c.insert(0, _c)
        r.insert(0, _r)
        if l > 0:
            r_up = self.upsample(_r)

    for l in range(self.nb_layers):
        ahat = self.conv_layers['ahat'][l](r[l])
        if l == 0:
            value = torch.Tensor([self.pixel_max]).to(device)
            ahat = torch.min(ahat, value.expand_as(ahat))
            frame_prediction = ahat
        # compute errors
        e_up = self.error_activation(ahat - a)
        e_down = self.error_activation(a - ahat)
        e.append(torch.cat((e_up, e_down), dim=self.channel_axis))
        if l < self.nb_layers - 1:
            a = self.conv_layers['a'][l](e[l])
            a = self.pool(a)  # target for next layer

    if self.output_mode == 'prediction':
        output = frame_prediction
    else:
        for l in range(self.nb_layers):
            layer_error = torch.mean(torch.flatten(e[l], start_dim=1), dim=-1, keepdim=True)
            if l == 0:
                all_error = layer_error
            else:
                all_error = torch.cat((all_error, layer_error), dim=-1)
        if self.output_mode == 'error' and image_n == 0:
            output = all_error
            output = output.unsqueeze(1)
        # elif self.output_mode == 'error':
        #     all_error = all_error.unsqueeze(1)
        #     output = torch.cat((output, all_error), dim=1)
        else:
            output = torch.cat((torch.flatten(frame_prediction, start_dim=1), all_error), dim=-1)

    states = r + c + e
    if self.extrap_start_time is not None:
        states += [frame_prediction, t + 1]
    return output, states
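In the extrapolation branch I use torch.where, where the 0-dim condition broadcasts over the whole frame; here is a quick sanity check of the semantics I have in mind (the tensors are just stand-ins for states[-2] and the actual input a):

import torch

t = torch.tensor(5)                        # timestep counter, a 0-dim tensor
prev_prediction = torch.zeros(1, 3, 8, 8)  # stand-in for states[-2]
frame = torch.ones(1, 3, 8, 8)             # stand-in for the actual input a

t_extrap = 3
a = torch.where(t >= t_extrap, prev_prediction, frame)
print(a.eq(prev_prediction).all())         # tensor(True): the prediction is fed back in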