Hi, my code has this kind of problem,RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time. But i just have one backward function. Here is my code. keys and values are initialised outside this loop as keys = torch.zeros(10, 512).cuda()
values = (torch.ones(10, 10) / 10).cuda().
for batch_idx, (noise_x, noise_y) in enumerate(noise_loader):
try:
clean_x, clean_y = clean_iter.next()
except:
clean_iter = iter(clean_loader)
clean_x, clean_y = clean_iter.next()
clean_x, clean_y = clean_x.cuda(), clean_y.cuda()
noise_x, noise_y = noise_x.cuda(), noise_y.cuda()
x_all = torch.cat((clean_x, noise_x),0)
optimizer_n.zero_grad()
features_clean, logits_clean = network(clean_x, get_feat=True)
features_noise, logits_noise = network(noise_x, get_feat=True)
print('check')
#keys_temp1, values_temp1 = Mea.module(clean_y, features_clean, F.softmax(logits_clean, dim=1))
keys, values = memory.module(clean_y, features_clean, F.softmax(logits_clean, dim=1), keys ,values)
mem_pred = memory.assimilation(keys, values, features_noise)
#loss_clean = CEloss(logits_clean, clean_y)
loss_ME = CEloss(mem_pred, noise_y)
#Loss_total = loss_clean + loss_ME
#Loss_total.backward()
loss_ME.backward()
optimizer_n.step()