When I am running below code I am getting error as
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [64, 1]], which is output 0 of TBackward, is at version 65; expected version 64 instead
import torch
# import numpy as np
from copy import deepcopy
from torch.autograd import Variable
from torch.nn import functional as F
from collections import OrderedDict
from embeddings import item, user
torch.autograd.set_detect_anomaly(True)
class user_preference_estimator(torch.nn.Module):
def __init__(self, config):
super(user_preference_estimator, self).__init__()
self.embedding_dim = config['embedding_dim']
self.fc1_in_dim = config['embedding_dim'] * 8
self.fc2_in_dim = config['first_fc_hidden_dim']
self.fc2_out_dim = config['second_fc_hidden_dim']
self.use_cuda = config['use_cuda']
self.item_emb = item(config)
self.user_emb = user(config)
self.fc1 = torch.nn.Linear(self.fc1_in_dim, self.fc2_in_dim)
self.fc2 = torch.nn.Linear(self.fc2_in_dim, self.fc2_out_dim)
self.linear_out = torch.nn.Linear(self.fc2_out_dim, 1)
def forward(self, x, training = True):
rate_idx = Variable(x[:, 0], requires_grad=False)
genre_idx = Variable(x[:, 1:26], requires_grad=False)
director_idx = Variable(x[:, 26:2212], requires_grad=False)
actor_idx = Variable(x[:, 2212:10242], requires_grad=False)
gender_idx = Variable(x[:, 10242], requires_grad=False)
age_idx = Variable(x[:, 10243], requires_grad=False)
occupation_idx = Variable(x[:, 10244], requires_grad=False)
area_idx = Variable(x[:, 10245], requires_grad=False)
"""
- You can check the embedding class for complete detail but just to summarize
- We are using embeddings for single feature and for list we are using liner layer
- Then we are concatenating it
- in this method we are concatenating item and user embedding as well, then then we are passing it to the fully connected layers
"""
item_emb = self.item_emb(rate_idx, genre_idx, director_idx, actor_idx)
user_emb = self.user_emb(gender_idx, age_idx, occupation_idx, area_idx)
x = torch.cat((item_emb, user_emb), 1)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
x = F.relu(x)
return self.linear_out(x)
class MeLU(torch.nn.Module):
def __init__(self, config):
super(MeLU, self).__init__()
self.use_cuda = config['use_cuda']
self.model = user_preference_estimator(config)
self.local_lr = config['local_lr']
self.store_parameters()
self.meta_optim = torch.optim.Adam(self.model.parameters(), lr=config['lr'])
self.local_update_target_weight_name = ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'linear_out.weight', 'linear_out.bias']
def store_parameters(self):
self.keep_weight = deepcopy(self.model.state_dict())
self.weight_name = list(self.keep_weight.keys())
self.weight_len = len(self.keep_weight)
self.fast_weights = OrderedDict()
def forward(self, support_set_x, support_set_y, query_set_x, num_local_update):
for idx in range(num_local_update):
if idx > 0:
self.model.load_state_dict(self.fast_weights)
weight_for_local_update = list(self.model.state_dict().values())
support_set_y_pred = self.model(support_set_x)
loss = F.mse_loss(support_set_y_pred, support_set_y.view(-1, 1))
self.model.zero_grad()
grad = torch.autograd.grad(loss, self.model.parameters(), create_graph=True)
# local update
for i in range(self.weight_len):
if self.weight_name[i] in self.local_update_target_weight_name:
self.fast_weights[self.weight_name[i]] = weight_for_local_update[i] - self.local_lr * grad[i]
else:
self.fast_weights[self.weight_name[i]] = weight_for_local_update[i]
self.model.load_state_dict(self.fast_weights)
query_set_y_pred = self.model(query_set_x)
self.model.load_state_dict(self.keep_weight)
return query_set_y_pred
def global_update(self, support_set_xs, support_set_ys, query_set_xs, query_set_ys, num_local_update):
batch_sz = len(support_set_xs)
losses_q = []
if self.use_cuda:
for i in range(batch_sz):
support_set_xs[i] = support_set_xs[i].cuda()
support_set_ys[i] = support_set_ys[i].cuda()
query_set_xs[i] = query_set_xs[i].cuda()
query_set_ys[i] = query_set_ys[i].cuda()
for i in range(batch_sz):
query_set_y_pred = self.forward(support_set_xs[i], support_set_ys[i], query_set_xs[i], num_local_update)
loss_q = F.mse_loss(query_set_y_pred, query_set_ys[i].view(-1, 1))
losses_q.append(loss_q)
losses_q = torch.stack(losses_q).mean(0)
self.meta_optim.zero_grad()
losses_q.backward()
self.meta_optim.step()
self.store_parameters()
return
When I am removing below lines from forward, the error is gone
self.model.load_state_dict(self.fast_weights)
self.model.load_state_dict(self.keep_weight)