Hello, I am trying to implement the MAML meta-learning algorithm for a GNN, but I keep getting an in-place operation error. My code is below; can anyone help me solve this issue?
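For reference, the update I am trying to implement is the standard two-level MAML update from Finn et al. (2017): for each task $i$, an inner gradient step on the support set,

$\theta_i' = \theta - \alpha \, \nabla_\theta \mathcal{L}_{S_i}(\theta),$

followed by an outer step on the query losses evaluated at the adapted weights,

$\theta \leftarrow \theta - \beta \, \nabla_\theta \sum_i \mathcal{L}_{Q_i}(\theta_i'),$

where the outer gradient must flow back through the inner updates (hence the create_graph=True in the inner loop).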
# Imports the function relies on. DataLoader is assumed to be the PyTorch Geometric
# loader, since the model takes (x, edge_index, edge_attr, batch).
import numpy as np
import torch
import torch.nn.functional as F
from collections import OrderedDict
from copy import deepcopy
from torch_geometric.loader import DataLoader

# device, generate_data_GNN, df_4seiz_only and gnn_model are defined earlier in the notebook.

def MAML(model, update_step, update_lr, outer_lr):
    list_pat = np.load("/home/arahmani/seizure_data/list_10pat.npy")
    model = model.to(device)
    model.train()
    meta_optim = torch.optim.Adam(model.parameters(), lr=outer_lr)
    dict_data = {}
    for pat in list_pat:
        X1_train, X1_val, patient1_trainloader = generate_data_GNN(df_4seiz_only, pat, 2500, "")
        dict_data[pat] = (X1_train, X1_val, patient1_trainloader)
    losses_q = [0 for _ in range(update_step + 1)]  # losses_q[k] is the query loss after k inner updates
    corrects = [0 for _ in range(update_step + 1)]
    keep_weight = deepcopy(model.state_dict())
    for i in range(10):  # outer loop over 10 tasks
        fast_weights = OrderedDict()
        # 1. run the i-th task and compute the support loss for k=0
        for data in patient1_trainloader:
            data = data.to(device)
            support_outp = model(data.x, data.edge_index, data.edge_attr.float(), data.batch)
            loss = F.cross_entropy(support_outp, data.y)
            grad = torch.autograd.grad(loss, model.parameters())
            for j, (weight_name, weight) in enumerate(model.named_parameters()):
                fast_weights[weight_name] = keep_weight[weight_name] - update_lr * grad[j]
        query_train_loader = DataLoader(X1_val, batch_size=len(X1_val))
        # loss and accuracy before the first update
        with torch.no_grad():
            for data in query_train_loader:
                data = data.to(device)
                logits_q = model(data.x, data.edge_index, data.edge_attr.float(), data.batch)
                loss_q = F.cross_entropy(logits_q, data.y)
                losses_q[0] += loss_q
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, data.y).sum().item()
                corrects[0] = corrects[0] + correct
        # loss and accuracy after the first update
        with torch.no_grad():
            model.load_state_dict(fast_weights)
            for data in query_train_loader:
                data = data.to(device)
                logits_q = model(data.x, data.edge_index, data.edge_attr.float(), data.batch)
                loss_q = F.cross_entropy(logits_q, data.y)
                losses_q[1] += loss_q
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, data.y).sum().item()
                corrects[1] = corrects[1] + correct
        for k in range(1, update_step):
            # 2. run the i-th task and compute the support loss for k = 1 .. K-1
            model.load_state_dict(fast_weights)
            for data in patient1_trainloader:
                data = data.to(device)
                support_outp = model(data.x, data.edge_index, data.edge_attr.float(), data.batch)
                loss = F.cross_entropy(support_outp, data.y)
                grad = torch.autograd.grad(loss, model.parameters(), create_graph=True)
                for j, (weight_name, weight) in enumerate(model.named_parameters()):
                    fast_weights[weight_name] = fast_weights[weight_name] - update_lr * grad[j]
            model.load_state_dict(fast_weights)
            for data in query_train_loader:
                data = data.to(device)
                logits_q = model(data.x, data.edge_index, data.edge_attr.float(), data.batch)
                loss_q = F.cross_entropy(logits_q, data.y)
                losses_q[k + 1] += loss_q
                with torch.no_grad():
                    pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                    correct = torch.eq(pred_q, data.y).sum().item()
                    corrects[k + 1] = corrects[k + 1] + correct
    # end of all tasks
    # average the last-step query loss over the 10 tasks
    loss_q = losses_q[-1] / 10
    model.load_state_dict(keep_weight)
    # optimize the theta parameters
    meta_optim.zero_grad()
    torch.autograd.set_detect_anomaly(True)
    loss_q.backward()
    # print('meta update')
    # for p in self.net.parameters()[:5]:
    #     print(torch.norm(p).item())
    meta_optim.step()
    # accs = np.array(corrects) / (querysz * task_num)
    return model

model1 = gnn_model.GCN(64)
best_model = MAML(model1, 3, 0.00001, 0.0001)

And this is the error I get:
RuntimeError Traceback (most recent call last)
/home/arahmani/seizure_detection/new_meta_code.ipynb Cell 8' in <module>
1 model1 = gnn_model.GCN(64)
----> 2 _ = MAML(model1, 3,0.00001,0.0001)
/home/arahmani/seizure_detection/new_meta_code.ipynb Cell 7' in MAML(model, update_step, update_lr, outer_lr)
84 meta_optim.zero_grad()
85 torch.autograd.set_detect_anomaly(True)
---> 86 loss_q.backward()
87 # print('meta update')
88 # for p in self.net.parameters()[:5]:
89 # print(torch.norm(p).item())
90 meta_optim.step()
File ~/.conda/envs/seizure_task/lib/python3.9/site-packages/torch/_tensor.py:307, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
298 if has_torch_function_unary(self):
299 return handle_torch_function(
300 Tensor.backward,
301 (self,),
(...)
305 create_graph=create_graph,
306 inputs=inputs)
--> 307 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File ~/.conda/envs/seizure_task/lib/python3.9/site-packages/torch/autograd/__init__.py:154, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
151 if retain_graph is None:
152 retain_graph = create_graph
--> 154 Variable._execution_engine.run_backward(
155 tensors, grad_tensors_, retain_graph, create_graph, inputs,
156 allow_unreachable=True, accumulate_grad=True)
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [64, 2]], which is output 0 of AsStridedBackward0, is at version 52; expected version 51 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
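If it helps, here is my understanding of the general pattern behind this error, as a minimal sketch with made-up tensors (nothing to do with my actual model): a tensor that autograd saved for the backward pass gets modified in place before backward() runs, so its version counter no longer matches the saved one.

import torch

w = torch.nn.Parameter(torch.randn(3, 2))
loss = (w ** 2).sum()            # pow saves w for the backward pass
with torch.no_grad():
    w.copy_(torch.zeros(3, 2))   # in-place copy bumps w's version counter
loss.backward()                  # RuntimeError: ... modified by an inplace operation

In my function, I suspect the in-place writes come from the repeated model.load_state_dict(...) calls between the forward passes and loss_q.backward(), since load_state_dict copies the new values into the existing parameter tensors in place.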