Hello, I am trying to reproduce the code of Kumar's JODIE model, available on GitHub at claws-lab/jodie ("Predicting Dynamic Embedding Trajectory in Temporal Interaction Networks", ACM SIGKDD 2019), with PyTorch version 1.7.1. Everything goes fine during training, but during model evaluation I get the following error: `RuntimeError: Function 'MseLossBackward' returned nan values in its 0th output.` The code used is the following:
def evaluate(data, embedding_dim, state_change, train_proportion, model, optimizer, user_embeddings_dystat, item_embeddings_dystat, user_embeddings_timeseries, item_embeddings_timeseries):
    """Run JODIE validation and testing over the post-training interaction stream.

    Replays interactions from ``train_end_idx`` to ``test_end_idx``, updating
    user/item embeddings online and accumulating a t-batched loss that is
    back-propagated once per time window (as in the JODIE reference code).

    Parameters
    ----------
    data : pandas.DataFrame — raw interaction table (converted via ``to_numpy``
        and parsed by ``load_network``).
    embedding_dim : int — width of the dynamic half of each embedding row.
    state_change : bool — if True, add the state-change prediction loss.
    train_proportion : float — fraction of interactions used for training;
        validation covers the next 20%, testing the 20% after that.
    model, optimizer : passed in but re-created below (see NOTE).
    *_dystat : tensors holding [dynamic | static] embeddings side by side.
    *_timeseries : per-interaction embedding history tensors (mutated in place).

    Returns
    -------
    dict with keys "validation" and "test", each mapping to ``[auc]``.
    """
    torch.autograd.set_detect_anomaly(True)

    # Parse the raw interaction table into id maps and aligned sequences.
    df = data.to_numpy()
    [user2id, user_sequence_id, user_timediffs_sequence, user_previous_itemid_sequence,
     item2id, item_sequence_id, item_timediffs_sequence,
     timestamp_sequence, feature_sequence, y_true] = load_network(df)

    num_interactions = len(user_sequence_id)
    num_users = len(user2id)
    num_items = len(item2id) + 1  # +1 reserves a dummy "no previous item" slot
    num_features = len(feature_sequence[0])
    # Inverse positive-label frequency, used to re-weight the rare class.
    true_labels_ratio = len(y_true) / (1.0 + sum(y_true))
    print("*** Network statistics:\n %d users\n %d items\n %d interactions\n %d/%d true labels ***\n\n"
          % (num_users, num_items, num_interactions, sum(y_true), len(y_true)))

    # Train / validation / test boundaries over the interaction stream.
    train_end_idx = validation_start_idx = int(num_interactions * train_proportion)
    test_start_idx = int(num_interactions * (train_proportion + 0.2))
    test_end_idx = int(num_interactions * (train_proportion + 0.4))

    # One backward pass per 1/500th of the total observed timespan (t-batch).
    timespan = timestamp_sequence[-1] - timestamp_sequence[0]
    tbatch_timespan = timespan / 500

    # NOTE(review): this re-creates an *untrained* model and optimizer,
    # shadowing the `model`/`optimizer` arguments. In the upstream JODIE
    # evaluation script the trained checkpoint is loaded here instead —
    # confirm a checkpoint load is not missing at the call site.
    model = JODIE(num_features, num_users, num_items, embedding_dim)
    weight = torch.Tensor([1, true_labels_ratio])
    crossEntropyLoss = nn.CrossEntropyLoss(weight=weight)
    MSELoss = nn.MSELoss()
    learning_rate = 1e-3
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)

    set_embeddings_training_end(user_embeddings_dystat, item_embeddings_dystat,
                                user_embeddings_timeseries, item_embeddings_timeseries,
                                user_sequence_id, item_sequence_id, train_end_idx)

    # Split the [dynamic | static] matrices; clone so in-place row updates
    # below do not write back into the *_dystat tensors.
    item_embeddings = item_embeddings_dystat[:, :embedding_dim].clone()
    item_embeddings_static = item_embeddings_dystat[:, embedding_dim:].clone()
    user_embeddings = user_embeddings_dystat[:, :embedding_dim].clone()
    user_embeddings_static = user_embeddings_dystat[:, embedding_dim:].clone()

    validation_predicted_y = []
    test_predicted_y = []
    validation_true_y = []
    test_true_y = []

    tbatch_start_time = None
    loss = 0
    with trange(train_end_idx, test_end_idx) as progress_bar:
        for j in progress_bar:
            progress_bar.set_description("%dth interaction for validation and testing" % j)

            # Current interaction: who, what, when, and the elapsed-time deltas.
            userid = user_sequence_id[j]
            itemid = item_sequence_id[j]
            feature = feature_sequence[j]
            user_timediff = user_timediffs_sequence[j]
            item_timediff = item_timediffs_sequence[j]
            timestamp = timestamp_sequence[j]
            if not tbatch_start_time:
                tbatch_start_time = timestamp
            itemid_previous = user_previous_itemid_sequence[j]

            # Gather the (1, dim) embedding rows involved in this interaction.
            user_embedding_input = user_embeddings[torch.LongTensor([userid])]
            user_embedding_static_input = user_embeddings_static[torch.LongTensor([userid])]
            item_embedding_input = item_embeddings[torch.LongTensor([itemid])]
            item_embedding_static_input = item_embeddings_static[torch.LongTensor([itemid])]
            feature_tensor = Variable(torch.Tensor([feature])).unsqueeze(0)
            user_timediffs_tensor = Variable(torch.Tensor([user_timediff])).unsqueeze(0)
            item_timediffs_tensor = Variable(torch.Tensor([item_timediff])).unsqueeze(0)
            item_embedding_previous = item_embeddings[torch.LongTensor([itemid_previous])]

            # Project the user forward in time and predict the next item embedding.
            user_projected_embedding = model.forward(user_embedding_input, item_embedding_previous,
                                                     timediffs=user_timediffs_tensor,
                                                     features=feature_tensor, select="project")
            user_item_embedding = torch.cat([user_projected_embedding,
                                             item_embedding_previous,
                                             item_embeddings_static[torch.LongTensor([itemid_previous])],
                                             user_embedding_static_input], dim=1)
            predicted_item_embedding = model.predicted_item_embedding(user_item_embedding)

            # BUG FIX: the MSE *targets* must be detached from the graph.
            # Using `.clone()` keeps them attached, so backward also flows
            # through the target branch (whose rows are later overwritten
            # in place), which is what triggers
            # "Function 'MseLossBackward' returned nan values".
            loss = loss + MSELoss(predicted_item_embedding,
                                  torch.cat([item_embedding_input, item_embedding_static_input], dim=1).detach())

            # Mutual RNN updates of the user and item dynamic embeddings.
            user_embedding_output = model.forward(user_embedding_input, item_embedding_input,
                                                  timediffs=user_timediffs_tensor,
                                                  features=feature_tensor.reshape(1, -1), select="user_update")
            item_embedding_output = model.forward(user_embedding_input, item_embedding_input,
                                                  timediffs=item_timediffs_tensor,
                                                  features=feature_tensor.reshape(1, -1), select="item_update")

            # Write the updated rows back (in place) and record the timeseries.
            item_embeddings[itemid, :] = item_embedding_output.squeeze(0)
            user_embeddings[userid, :] = user_embedding_output.squeeze(0)
            user_embeddings_timeseries[j, :] = user_embedding_output.squeeze(0)
            item_embeddings_timeseries[j, :] = item_embedding_output.squeeze(0)

            # Temporal-smoothness regularizers (targets detached, as above).
            loss = loss + MSELoss(item_embedding_output, item_embedding_input.detach())
            loss = loss + MSELoss(user_embedding_output, user_embedding_input.detach())

            if state_change:
                loss = loss + calculate_state_prediction_loss(model, [j], user_embeddings_timeseries,
                                                              y_true, crossEntropyLoss)

            # End of a t-batch window: backprop the accumulated loss, step,
            # then cut the graph so the next window starts fresh.
            if timestamp - tbatch_start_time > tbatch_timespan:
                tbatch_start_time = timestamp
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                loss = 0
                item_embeddings.detach_()
                user_embeddings.detach_()
                item_embeddings_timeseries.detach_()
                user_embeddings_timeseries.detach_()

            # Record the state-change probability for AUC computation.
            prob = model.predict_label(user_embedding_output)
            if j < test_start_idx:
                validation_predicted_y.extend(prob.data.cpu().numpy())
                validation_true_y.extend([y_true[j]])
            else:
                test_predicted_y.extend(prob.data.cpu().numpy())
                test_true_y.extend([y_true[j]])

    validation_predicted_y = np.array(validation_predicted_y)
    test_predicted_y = np.array(test_predicted_y)

    # Column 1 holds the positive-class probability.
    performance_dict = dict()
    auc = roc_auc_score(validation_true_y, validation_predicted_y[:, 1])
    performance_dict["validation"] = [auc]
    auc = roc_auc_score(test_true_y, test_predicted_y[:, 1])
    performance_dict["test"] = [auc]
    return performance_dict
Can you help me, please?