Hi, I'm trying to train a system of two related networks in PyTorch. However, I can't seem to call .backward() on the two losses independently. When I do, I get one of the following errors:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
Or (when I add retain_graph=True):
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1536, 10]], which is output 0 of TBackward, is at version 2; expected version 1 instead.
It seems as though PyTorch is not treating the two models' computation graphs as independent. Is there any way to fix this?
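For reference, here is a stripped-down toy (not my real models, just the pattern I think I've reduced it to) that triggers the same two errors for me, because a tensor carrying the previous iteration's graph gets reused in the next loss:

import torch

model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.Linear(10, 10))
opt = torch.optim.SGD(model.parameters(), lr=0.1)
state = torch.zeros(10)  # stands in for the `centers` tensor in my real code

for step in range(3):
    out = model(torch.randn(4, 10))
    loss = ((out - state) ** 2).sum()  # `state` still holds the previous iteration's graph
    opt.zero_grad()
    loss.backward()    # second iteration: "Trying to backward through the graph a second time"
    opt.step()         # with retain_graph=True instead: the in-place "version" error
    state = state + out.mean(dim=0)  # carries this iteration's graph into the next one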
Here is my actual code (sorry if the formatting is off; this is my first time posting here):
def main():
    args = parse_args()
    (student_model, batch_size, num_epochs, slr) = init_student(args.student_config)
    (teacher_model, gamma, tlr) = init_teacher(args.teacher_config)
    student_model.cuda()
    teacher_model.cuda()
    print("Models Initialized")
    (T, TL, TLIDS, TS, TSL, TSLIDS) = load_cifar(args.data_path)
    print("Dataset Loaded")
    s_criterion = torch.nn.MSELoss(reduction='sum')
    student_optimizer = torch.optim.SGD(student_model.parameters(), lr=slr, momentum=0)
    TIN = []
    TOUT = []
    TR = []
    torch.autograd.set_detect_anomaly(True)
    centers = torch.randn(10, 10).cuda()
    print("Beginning Training")
    for n in range(num_epochs):
        for t in range(int(10000 / batch_size)):
            # if (n == 0) and (t == 0):
            #     SP = base_sm([T[t*batch_size:(t+1)*batch_size, :, :, :]])
            #     STP = base_sm([TS[t*batch_size:(t+1)*batch_size, :, :, :]])
            #     centers = calculate_pred_centers(torch.cat((SP, STP), 0), torch.cat((TLIDS[t*batch_size:(t+1)*batch_size], TSLIDS[t*batch_size:(t+1)*batch_size]), 0))
            #     continue
            print(str(n) + ' ' + str(t) + ' hi')
            # Student forward pass and update
            SP = student_model([T[t*batch_size:(t+1)*batch_size, :, :, :]])
            (closs, targets) = calculate_targets(SP, TLIDS[t*batch_size:(t+1)*batch_size], centers)
            loss = s_criterion(SP, targets)
            student_optimizer.zero_grad()
            loss.backward(retain_graph=True)  # this backward() call raises the errors above
            student_optimizer.step()
            # STP = student_model([TS[t*batch_size:(t+1)*batch_size, :, :, :]])
            accuracy = calculate_accuracy(SP, TLIDS[t*batch_size:(t+1)*batch_size], centers)
            # if (n > 0) or (t > 1):
            #     daccuracy = accuracy - prev_accuracy
            #     dloss = closs - prev_loss
            #     reward = (daccuracy / (1 - (torch.sum(accuracy) / 10))) - (dloss / (torch.sum(closs) / 10))
            #     TR.append(reward)
            # Teacher update: the teacher's output is added into `centers`,
            # which the student's next loss depends on
            teacher_model = train_teacher(teacher_model, TIN, TOUT, TR, tlr)
            pcenters = calculate_pred_centers(torch.cat((SP, SP), 0), torch.cat((TLIDS[t*batch_size:(t+1)*batch_size], TLIDS[t*batch_size:(t+1)*batch_size]), 0))
            teacher_input = (pcenters - centers, torch.cat((accuracy, closs), 1))
            teacher_output = teacher_model([teacher_input])
            centers = centers + teacher_output
            TIN.append(teacher_input)
            TOUT.append(1)
            TR.append(accuracy)
            prev_loss, prev_accuracy = closs, accuracy
print("EPOCH " + str(n) + " COMPLETE!")