Hi, I get the error “Trying to backward through the graph a second time, but the buffers have already been freed.” even though I call hidden = hidden.detach() after loss.backward().
I split a long sequence into many batches, where each batch is a fixed-length slice of the sequence starting at a given time step.
I then run each slice through my network and get output and hidden. However, I don't want to backpropagate through time at every time step, but only at some interval, say every 20 time steps.
So I run through 20 batches, updating output and hidden at each step. At step 20 I calculate the loss and call loss.backward(). To clear the computational graph for the next 20 steps I detach the hidden state with hidden = hidden.detach(). However, at the next loss.backward() I get the error above.
Is there a way to completely delete the computational graph?
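For reference, here is a minimal sketch of the detach-at-interval pattern I am describing (the GRU, linear head, tbptt_step, and random data below are just placeholders, not my actual model or data):

import torch
import torch.nn as nn

# Placeholder model and data, only to illustrate the intended TBPTT pattern
rnn = nn.GRU(input_size=3, hidden_size=16, batch_first=True)
head = nn.Linear(16, 3)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(list(rnn.parameters()) + list(head.parameters()), lr=0.01)

seq = torch.randn(1, 200, 3)     # one long sequence
target = torch.randn(1, 200, 3)
tbptt_step = 20
hidden = None

for t in range(seq.size(1)):
    out, hidden = rnn(seq[:, t:t + 1, :], hidden)  # one time step per forward pass
    if (t + 1) % tbptt_step == 0:
        loss = criterion(head(out), target[:, t:t + 1, :])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Detach so the next backward() does not reach into the previous interval's graph
        hidden = hidden.detach()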
My training loop:
def train(epoch):
    print("Training Initiated!")
    model.train()
    losses = []
    # Run through all sequences:
    for step, (data, target) in enumerate(train_set_all):
        print("Sequence #%d" % (step))
        X = data
        y = target
        # MSELoss:
        y = y[:, 0:1, :]
        y = y.view(-1, 3)
        max_seq_size = 4000
        # If the track is shorter than max_seq_size samples, skip it
        if X.size(2) < max_seq_size:
            print("Sequence too short, skipped!")
            continue
        stride_length = 8  # with stride 8 and 3 poolings, 1 output sample corresponds to 8 input samples
        batch_size = 400
        # Clip track
        X = X[:, :, :max_seq_size]  # Clip data to max_seq_size
        # Zero-pad the start of the sequence so the first prediction only uses the first 8 samples for classification
        m = nn.ConstantPad1d((batch_size - stride_length, 0), 0)
        X = m(X)
        # Generate batches from the sequence:
        batches = []
        for start_pos in range(0, max_seq_size, stride_length):
            end_pos = start_pos + batch_size
            batch = np.copy(X[:, :, start_pos:end_pos])
            batches.append(torch.from_numpy(batch))
        # Initialize the hidden state once per sequence:
        hidden = Net.init_hidden(model)
        # How often to BPTT
        TBTT_step = 10
        # These two lists hold the model outputs and the targets:
        predicted, true = [], []
        # Run one batch at a time:
        for idx, nbatch in enumerate(batches):
            start_time = time.time()
            print("Batch number #%d" % (idx))
            #print(nbatch.size())
            #print(nbatch)
            output, hidden = model(nbatch, hidden, batch_size)
            predicted.append(output[0])
            true.append(y[0])
            # Calculate the loss and step the optimizer every TBTT_step batches
            # (hidden is updated at every batch):
            if idx % TBTT_step == 0:
                # NLL
                #loss = criterion(output, y[:,nbatch*batch_size+batch_size-1:nbatch*batch_size+batch_size].reshape(1).long())
                # MSE
                loss = torch.mean(criterion(torch.cat(predicted), torch.cat(true)))
                losses.append(loss.item())
                #print(loss.item())
                #print(loss)
                # Reset gradient
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Release hidden
                hidden = hidden.detach()
            elapsed_time = time.time() - start_time
            print(elapsed_time)
        if step > 500:  # Only run through 500 sequences (tracks)
            return losses